Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ class Metric(SortableBase, SerializationMixin):
properties: Properties specific to a particular metric.
"""

# The set of exception types stored in a ``MetricFetchE.exception`` that are
# recoverable in ``orchestrator._fetch_and_process_trials_data_results()``.
# Exception may be a subclass of any of these types. If you want your metric
# to never fail the trial, set this to ``{Exception}`` in your metric subclass.
recoverable_exceptions: set[type[Exception]] = set()
has_map_data: bool = False

def __init__(
Expand Down Expand Up @@ -164,17 +159,6 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
"""
return timedelta(0)

@classmethod
def is_recoverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
"""Checks whether the given MetricFetchE is recoverable for this metric class
in ``orchestrator._fetch_and_process_trials_data_results``.
"""
if metric_fetch_e.exception is None:
return False
return any(
isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions
)

# NOTE: This is rarely overridden –– only if you want to fetch data in groups
# consisting of multiple different metric classes, for data to be fetched together.
# This makes sense only if `fetch_trial_data_multi` or `fetch_experiment_data_multi`
Expand Down
283 changes: 154 additions & 129 deletions ax/orchestration/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
from ax.core.utils import compute_metric_availability, MetricAvailability
from ax.exceptions.core import (
AxError,
DataRequiredError,
Expand Down Expand Up @@ -80,13 +81,6 @@
"of an optimization and if at least {min_failed} trials have been "
"failed/abandoned, potentially automatically due to issues with the trial."
)
METRIC_FETCH_ERR_MESSAGE = (
"A majority of the trial failures encountered are due to metric fetching errors. "
"This could mean the metrics are flaky, broken, or misconfigured. Please check "
"that the trial processes/jobs are successfully producing the expected metrics and "
"that the metric is correctly configured."
)

EXPECTED_STAGED_MSG = (
"Expected all trials to be in status {expected} after running or staging, "
"found {t_idx_to_status}."
Expand Down Expand Up @@ -191,13 +185,6 @@ class Orchestrator(WithDBSettingsBase, BestPointMixin):
# Saved as a property so that it can be accessed after optimization is complete (ex.
# for global stopping saving calculation).
_num_remaining_requested_trials: int = 0
# Total number of MetricFetchEs encountered during the course of optimization. Note
# this is different from and may be greater than the number of trials that have
# been marked either FAILED or ABANDONED due to metric fetching errors.
_num_metric_fetch_e_encountered: int = 0
# Number of trials that have been marked either FAILED or ABANDONED due to
# MetricFetchE being encountered during _fetch_and_process_trials_data_results
_num_trials_bad_due_to_err: int = 0
# Keeps track of whether the allowed failure rate has been exceeded during
# the optimization. If true, allows any pending trials to finish and raises
# an error through self._complete_optimization.
Expand Down Expand Up @@ -1073,83 +1060,63 @@ def summarize_final_result(self) -> OptimizationResult:
"""
return OptimizationResult()

def _check_if_failure_rate_exceeded(self, force_check: bool = False) -> bool:
"""Checks if the failure rate (set in Orchestrator options) has been exceeded at
any point during the optimization.
def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None:
"""Raises an exception if the failure rate (set in Orchestrator options) has
been exceeded at any point during the optimization.

NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate.
The failure rate is computed as the ratio of "bad" trials to total trials
created by this orchestrator. "Bad" trials include:
- Execution failures: trials with FAILED or ABANDONED status.
- Metric-incomplete trials: COMPLETED trials whose metric data is not
fully available (as determined by ``compute_metric_availability``).

Args:
force_check: Indicates whether to force a failure-rate check
regardless of the number of trials that have been executed. If False
(default), the check will be skipped if the optimization has fewer than
five failed trials. If True, the check will be performed unless there
are 0 failures.
``min_failed_trials_for_failure_rate_check`` bad trials. If True, the
check will be performed unless there are 0 bad trials.
"""
# Count runner-level failures (FAILED + ABANDONED).
num_execution_failures = self._num_bad_in_orchestrator()

Effect on state:
If the failure rate has been exceeded, a warning is logged and the private
attribute `_failure_rate_has_been_exceeded` is set to True, which causes the
`_get_max_pending_trials` to return zero, so that no further trials are
scheduled and an error is raised at the end of the optimization.
# Count completed trials with incomplete metric availability.
num_metric_incomplete, missing_metrics_by_trial = (
self._get_metric_incomplete_trials()
)

Returns:
Boolean representing whether the failure rate has been exceeded.
"""
if self._failure_rate_has_been_exceeded:
return True
num_bad = num_execution_failures + num_metric_incomplete

num_bad_in_orchestrator = self._num_bad_in_orchestrator()
# skip check if 0 failures
if num_bad_in_orchestrator == 0:
return False
if not self._failure_rate_has_been_exceeded:
# Skip check if 0 bad trials.
if num_bad == 0:
return

# skip check if fewer than min_failed_trials_for_failure_rate_check failures
# unless force_check is True
if (
num_bad_in_orchestrator
< self.options.min_failed_trials_for_failure_rate_check
and not force_check
):
return False
# Skip check if fewer than min threshold unless force_check.
if (
num_bad < self.options.min_failed_trials_for_failure_rate_check
and not force_check
):
return

num_ran_in_orchestrator = self._num_ran_in_orchestrator()
failure_rate_exceeded = (
num_bad_in_orchestrator / num_ran_in_orchestrator
) > self.options.tolerated_trial_failure_rate
num_ran_in_orchestrator = self._num_ran_in_orchestrator()
failure_rate_exceeded = (
num_bad / num_ran_in_orchestrator
) > self.options.tolerated_trial_failure_rate

if not failure_rate_exceeded:
return

if failure_rate_exceeded:
if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2:
self.logger.warning(
"MetricFetchE INFO: Sweep aborted due to an exceeded error rate, "
"which was primarily caused by failure to fetch metrics. Please "
"check if anything could cause your metrics to be flaky or "
"broken."
)
# NOTE: this private attribute causes `_get_max_pending_trials` to
# return zero, which causes no further trials to be scheduled.
self._failure_rate_has_been_exceeded = True
return True

return False

def error_if_failure_rate_exceeded(self, force_check: bool = False) -> None:
"""Raises an exception if the failure rate (set in Orchestrator options) has
been exceeded at any point during the optimization.

NOTE: Both FAILED and ABANDONED trial statuses count towards the failure rate.

Args:
force_check: Indicates whether to force a failure-rate check
regardless of the number of trials that have been executed. If False
(default), the check will be skipped if the optimization has fewer than
five failed trials. If True, the check will be performed unless there
are 0 failures.
"""
if self._check_if_failure_rate_exceeded(force_check=force_check):
raise self._get_failure_rate_exceeded_error(
num_bad_in_orchestrator=self._num_bad_in_orchestrator(),
num_ran_in_orchestrator=self._num_ran_in_orchestrator(),
)
raise self._get_failure_rate_exceeded_error(
num_execution_failures=num_execution_failures,
num_metric_incomplete=num_metric_incomplete,
num_ran_in_orchestrator=self._num_ran_in_orchestrator(),
missing_metrics_by_trial=missing_metrics_by_trial,
)

def _error_if_status_quo_infeasible(self) -> None:
"""Raises an exception if the status-quo arm is infeasible and the
Expand Down Expand Up @@ -2032,9 +1999,13 @@ def _fetch_and_process_trials_data_results(
self,
trial_indices: Iterable[int],
) -> dict[int, dict[str, MetricFetchResult]]:
"""
Fetches results from experiment and modifies trial statuses depending on
success or failure.
"""Fetch trial data results and log any metric fetch errors.

Metric fetch errors are logged but do NOT change trial status.
``MetricAvailability`` (computed via ``compute_metric_availability``)
tracks data completeness separately, and the failure rate check in
``error_if_failure_rate_exceeded`` uses it to detect persistent
metric issues.
"""

try:
Expand Down Expand Up @@ -2085,41 +2056,12 @@ def _fetch_and_process_trials_data_results(
f"Failed to fetch {metric_name} for trial {trial_index} with "
f"status {status}, found {metric_fetch_e}."
)
self._num_metric_fetch_e_encountered += 1
self._report_metric_fetch_e(
trial=self.experiment.trials[trial_index],
metric_name=metric_name,
metric_fetch_e=metric_fetch_e,
)

# If the fetch failure was for a metric in the optimization config (an
# objective or constraint) mark the trial as failed
optimization_config = self.experiment.optimization_config
if (
optimization_config is not None
and metric_name in optimization_config.metric_names
and not self.experiment.metrics[metric_name].is_recoverable_fetch_e(
metric_fetch_e=metric_fetch_e
)
):
status = self._mark_err_trial_status(
trial=self.experiment.trials[trial_index],
metric_name=metric_name,
metric_fetch_e=metric_fetch_e,
)
self.logger.warning(
f"MetricFetchE INFO: Because {metric_name} is an objective, "
f"marking trial {trial_index} as {status}."
)
self._num_trials_bad_due_to_err += 1
continue

self.logger.info(
"MetricFetchE INFO: Continuing optimization even though "
"MetricFetchE encountered."
)
continue

return results

def _report_metric_fetch_e(
Expand All @@ -2128,39 +2070,122 @@ def _report_metric_fetch_e(
metric_name: str,
metric_fetch_e: MetricFetchE,
) -> None:
"""Hook for subclasses to react to metric fetch errors.

Called once per metric fetch error during
``_fetch_and_process_trials_data_results``. The default
implementation is a no-op; override in subclasses to add custom
reporting (e.g., creating error tables or pastes).
"""
pass

def _mark_err_trial_status(
def _get_metric_incomplete_trials(
self,
trial: BaseTrial,
metric_name: str | None = None,
metric_fetch_e: MetricFetchE | None = None,
) -> TrialStatus:
trial.mark_abandoned(
reason=metric_fetch_e.message if metric_fetch_e else None, unsafe=True
) -> tuple[int, dict[int, set[str]]]:
"""Count completed trials with incomplete metric availability and identify
which metrics are missing for each.

Required metrics include optimization config metrics and any explicitly
defined early stopping strategy metrics.

Returns:
A tuple of (num_metric_incomplete, missing_metrics_by_trial) where
missing_metrics_by_trial maps trial index to the set of missing
metric names.
"""
opt_config = self.experiment.optimization_config
if opt_config is None:
return 0, {}

completed_trial_indices = [
t.index
for t in self.experiment.trials.values()
if t.status == TrialStatus.COMPLETED
and t.index >= self._num_preexisting_trials
]
if len(completed_trial_indices) == 0:
return 0, {}

required_metrics = set(opt_config.metric_names)

# Include explicitly defined early stopping strategy metrics.
# ESS stores metric *signatures*, which may differ from metric names,
# so we resolve them via experiment.signature_to_metric.
ess = self.options.early_stopping_strategy
ess_signatures = ess.metric_signatures if ess is not None else None
if ess_signatures is not None:
for sig in ess_signatures:
metric = self.experiment.signature_to_metric[sig]
required_metrics.add(metric.name)

metric_availabilities = compute_metric_availability(
experiment=self.experiment,
trial_indices=completed_trial_indices,
metric_names=required_metrics,
)
return TrialStatus.ABANDONED

# Identify which specific metrics are missing per trial.
data = self.experiment.lookup_data(trial_indices=completed_trial_indices)
metrics_per_trial: dict[int, set[str]] = {}
if len(data.metric_names) > 0:
df = data.full_df
for trial_idx, group in df.groupby("trial_index")["metric_name"]:
metrics_per_trial[int(trial_idx)] = set(group.unique())

missing_metrics_by_trial: dict[int, set[str]] = {}
for idx, avail in metric_availabilities.items():
if avail != MetricAvailability.COMPLETE:
available = metrics_per_trial.get(idx, set())
missing_metrics_by_trial[idx] = required_metrics - available

return len(missing_metrics_by_trial), missing_metrics_by_trial

def _get_failure_rate_exceeded_error(
self,
num_bad_in_orchestrator: int,
num_execution_failures: int,
num_metric_incomplete: int,
num_ran_in_orchestrator: int,
missing_metrics_by_trial: dict[int, set[str]],
) -> FailureRateExceededError:
return FailureRateExceededError(
(
f"{METRIC_FETCH_ERR_MESSAGE}\n"
if self._num_trials_bad_due_to_err > num_bad_in_orchestrator / 2
else ""
"""Build an actionable error message describing why the failure rate was
exceeded, including runner failures, metric-incomplete trials, which
metrics are missing, and which trials are affected.
"""
num_bad = num_execution_failures + num_metric_incomplete
observed_rate = num_bad / num_ran_in_orchestrator

parts: list[str] = []
parts.append(
f"Failure rate exceeded: {num_bad} of {num_ran_in_orchestrator} "
f"trials were unsuccessful (observed rate: {observed_rate:.0%}, tolerance: "
f"{self.options.tolerated_trial_failure_rate:.0%}). "
f"Checks are triggered when at least "
f"{self.options.min_failed_trials_for_failure_rate_check} trials "
"are unsuccessful or at the end of the optimization."
)

if num_execution_failures > 0:
parts.append(
f"{num_execution_failures} trial(s) failed at the execution "
"level (FAILED or ABANDONED). Check any trial evaluation "
"processes/jobs to see why they are failing."
)
+ " Original error message: "
+ FAILURE_EXCEEDED_MSG.format(
f_rate=self.options.tolerated_trial_failure_rate,
n_failed=num_bad_in_orchestrator,
n_ran=num_ran_in_orchestrator,
min_failed=self.options.min_failed_trials_for_failure_rate_check,
observed_rate=float(num_bad_in_orchestrator) / num_ran_in_orchestrator,

if num_metric_incomplete > 0:
all_missing: set[str] = set()
for missing in missing_metrics_by_trial.values():
all_missing.update(missing)
affected_trials = sorted(missing_metrics_by_trial.keys())

parts.append(
f"{num_metric_incomplete} trial(s) have incomplete metric data. "
f"Missing metrics: {sorted(all_missing)}. "
f"Affected trials: {affected_trials}. "
"Check that your metric fetching infrastructure is healthy "
"and that the metrics are being logged correctly."
)
)

return FailureRateExceededError("\n".join(parts))

def _warn_if_non_terminal_trials(self) -> None:
"""Warns if there are any non-terminal trials on the experiment."""
Expand Down
Loading
Loading