Skip to content

octopus.metrics

Init metrics.

MLType

Bases: StrEnum

Machine learning task types.

Source code in octopus/types.py
class MLType(StrEnum):
    """Machine learning task types.

    Used by the metric registry to declare which tasks a metric applies to
    (see Metric.ml_types / Metric.supports_ml_type).
    """

    BINARY = "binary"  # two-class classification
    MULTICLASS = "multiclass"  # classification with more than two classes
    REGRESSION = "regression"  # continuous target
    TIMETOEVENT = "timetoevent"  # time-to-event tasks; scored via Metric.calculate_t2e

Metric

Metric instance.

Represents a metric with its configuration and calculation methods.

Source code in octopus/metrics/config.py
@define
class Metric:
    """Metric instance.

    Represents a metric with its configuration and calculation methods.
    """

    name: str
    metric_function: MetricFunction = field(validator=validators.is_callable())
    ml_types: frozenset[MLType] = field(converter=to_ml_types_frozenset, validator=validate_ml_types)
    higher_is_better: bool = field(validator=validators.instance_of(bool))
    prediction_type: PredictionType = field(converter=PredictionType, validator=validators.in_(list(PredictionType)))
    scorer_string: str = field(validator=validators.instance_of(str))  # needed for some sklearn functionalities
    metric_params: dict[str, Any] = field(factory=dict)

    def supports_ml_type(self, ml_type: MLType) -> bool:
        """Check if this metric supports the given ml_type."""
        return ml_type in self.ml_types

    @property
    def direction(self) -> MetricDirection:
        """Optimization direction for Optuna ('maximize' or 'minimize')."""
        return MetricDirection.MAXIMIZE if self.higher_is_better else MetricDirection.MINIMIZE

    def calculate(self, y_true: OctoArrayLike, y_pred: OctoArrayLike, **kwargs) -> float:
        """Calculate metric for classification/regression tasks.

        Args:
            y_true: True target values
            y_pred: Predicted values (predictions or probabilities depending on prediction_type)
            **kwargs: Additional keyword arguments passed to metric function
                (override entries of metric_params on key collision)

        Returns:
            Metric value as float

        Raises:
            ValueError: If called on a time-to-event metric
        """
        if self.supports_ml_type(MLType.TIMETOEVENT):
            raise ValueError(
                f"Metric '{self.name}' is a time-to-event metric. "
                "Use calculate_t2e(event_indicator, event_time, estimate) instead."
            )
        # Merge metric_params with any additional kwargs, mirroring calculate_t2e.
        # BUG FIX: kwargs were previously accepted (and documented) but silently dropped.
        params = {**self.metric_params, **kwargs}
        return float(self.metric_function(y_true, y_pred, **params))

    def calculate_t2e(
        self, event_indicator: OctoArrayLike, event_time: OctoArrayLike, estimate: OctoArrayLike, **kwargs
    ) -> float:
        """Calculate metric for time-to-event tasks.

        Args:
            event_indicator: Boolean array indicating whether event occurred
            event_time: Array of event/censoring times
            estimate: Predicted risk/survival estimates from model
            **kwargs: Additional keyword arguments passed to metric function
                (override entries of metric_params on key collision)

        Returns:
            Metric value as float

        Raises:
            ValueError: If called on a non-time-to-event metric
        """
        if not self.supports_ml_type(MLType.TIMETOEVENT):
            raise ValueError(
                f"Metric '{self.name}' is not a time-to-event metric. Use calculate(y_true, y_pred) instead."
            )

        # Merge metric_params with any additional kwargs
        params = {**self.metric_params, **kwargs}
        result = self.metric_function(event_indicator, event_time, estimate, **params)

        # Handle tuple return (some T2E metrics return tuple)
        return float(result[0] if isinstance(result, tuple) else result)

direction property

Optimization direction for Optuna ('maximize' or 'minimize').

calculate(y_true, y_pred, **kwargs)

Calculate metric for classification/regression tasks.

Parameters:

Name Type Description Default
y_true OctoArrayLike

True target values

required
y_pred OctoArrayLike

Predicted values (predictions or probabilities depending on prediction_type)

required
**kwargs

Additional keyword arguments passed to metric function

{}

Returns:

Type Description
float

Metric value as float

Raises:

Type Description
ValueError

If called on a time-to-event metric

Source code in octopus/metrics/config.py
def calculate(self, y_true: OctoArrayLike, y_pred: OctoArrayLike, **kwargs) -> float:
    """Calculate metric for classification/regression tasks.

    Args:
        y_true: True target values
        y_pred: Predicted values (predictions or probabilities depending on prediction_type)
        **kwargs: Additional keyword arguments passed to metric function
            (override entries of metric_params on key collision)

    Returns:
        Metric value as float

    Raises:
        ValueError: If called on a time-to-event metric
    """
    if self.supports_ml_type(MLType.TIMETOEVENT):
        raise ValueError(
            f"Metric '{self.name}' is a time-to-event metric. "
            "Use calculate_t2e(event_indicator, event_time, estimate) instead."
        )
    # Merge metric_params with any additional kwargs, mirroring calculate_t2e.
    # BUG FIX: kwargs were previously accepted (and documented) but silently dropped.
    params = {**self.metric_params, **kwargs}
    return float(self.metric_function(y_true, y_pred, **params))

calculate_t2e(event_indicator, event_time, estimate, **kwargs)

Calculate metric for time-to-event tasks.

Parameters:

Name Type Description Default
event_indicator OctoArrayLike

Boolean array indicating whether event occurred

required
event_time OctoArrayLike

Array of event/censoring times

required
estimate OctoArrayLike

Predicted risk/survival estimates from model

required
**kwargs

Additional keyword arguments passed to metric function

{}

Returns:

Type Description
float

Metric value as float

Raises:

Type Description
ValueError

If called on a non-time-to-event metric

Source code in octopus/metrics/config.py
def calculate_t2e(
    self, event_indicator: OctoArrayLike, event_time: OctoArrayLike, estimate: OctoArrayLike, **kwargs
) -> float:
    """Score a time-to-event (survival) metric.

    Args:
        event_indicator: Boolean array indicating whether event occurred
        event_time: Array of event/censoring times
        estimate: Predicted risk/survival estimates from model
        **kwargs: Extra keyword arguments forwarded to the metric function
            (take precedence over configured metric_params)

    Returns:
        Metric value as float

    Raises:
        ValueError: If called on a non-time-to-event metric
    """
    if not self.supports_ml_type(MLType.TIMETOEVENT):
        raise ValueError(
            f"Metric '{self.name}' is not a time-to-event metric. Use calculate(y_true, y_pred) instead."
        )

    # Per-call kwargs override the configured metric_params.
    call_params = {**self.metric_params, **kwargs}
    value = self.metric_function(event_indicator, event_time, estimate, **call_params)

    # Some T2E metric functions return a tuple; the score is its first element.
    if isinstance(value, tuple):
        value = value[0]
    return float(value)

supports_ml_type(ml_type)

Check if this metric supports the given ml_type.

Source code in octopus/metrics/config.py
def supports_ml_type(self, ml_type: MLType) -> bool:
    """Return True when the given ML task type is among those this metric handles."""
    supported = self.ml_types
    return ml_type in supported

Metrics

Central registry for metrics.

Usage

Get metric instance

metric = Metrics.get_instance("AUCROC")
metric.calculate(y_true, y_pred)

Get direction

direction = Metrics.get_direction("AUCROC")

Source code in octopus/metrics/core.py
class Metrics:
    """Central registry for metrics.

    Usage:
        # Get metric instance
        metric = Metrics.get_instance("AUCROC")
        metric.calculate(y_true, y_pred)

        # Get direction
        direction = Metrics.get_direction("AUCROC")
    """

    # Internal registry: metric name -> function returning Metric
    _config_factories: ClassVar[dict[str, Callable[[], Metric]]] = {}

    # Internal cache: metric name -> Metric
    _metric_configs: ClassVar[dict[str, Metric]] = {}

    @classmethod
    def get_all_metrics(cls) -> dict[str, Callable[[], Metric]]:
        """Get all registered metric factory functions.

        Returns:
            Dictionary mapping metric names to their factory functions.
            A shallow copy is returned so callers cannot mutate the internal
            registry; use register() to add metrics.
        """
        # Copy defensively: handing out the live registry would let callers
        # bypass register() and its duplicate-name check.
        return dict(cls._config_factories)

    @classmethod
    def register(cls, name: str) -> Callable[[Callable[[], Metric]], Callable[[], Metric]]:
        """Register a metric factory function under a given name.

        Args:
            name: The name to register the metric under.

        Returns:
            Decorator function.

        Raises:
            ValueError: If a metric is already registered under ``name``
                (raised when the decorator is applied).
        """

        def decorator(factory: Callable[[], Metric]) -> Callable[[], Metric]:
            if name in cls._config_factories:
                raise ValueError(f"Metric '{name}' is already registered.")
            cls._config_factories[name] = factory
            return factory

        return decorator

    @classmethod
    def get_instance(cls, name: str) -> Metric:
        """Get metric instance by name.

        This is the primary method for getting a metric to use for calculation.
        Returns a Metric instance that has calculate() and calculate_t2e() methods.

        Args:
            name: The name of the metric to retrieve.

        Returns:
            Metric instance with calculate methods.

        Raises:
            UnknownMetricError: If no metric with the specified name is found.

        Usage:
            metric = Metrics.get_instance("AUCROC")
            value = metric.calculate(y_true, y_pred)
        """
        # Return cached config if available
        if name in cls._metric_configs:
            return cls._metric_configs[name]

        # Lookup factory
        factory = cls._config_factories.get(name)
        if factory is None:
            available = ", ".join(sorted(cls._config_factories.keys()))
            raise UnknownMetricError(
                f"Unknown metric '{name}'. Available metrics are: {available}. Please check the metric name and try again."
            )

        # Build config via factory and enforce name consistency
        config = factory()
        object.__setattr__(config, "name", name)
        cls._metric_configs[name] = config
        return config

    @classmethod
    def get_direction(cls, name: str) -> MetricDirection:
        """Get the optuna direction by name.

        Args:
            name: The name of the metric.

        Returns:
            MetricDirection.MAXIMIZE if higher_is_better is True, else MetricDirection.MINIMIZE.
        """
        # Delegate to Metric.direction so the higher_is_better -> direction
        # mapping lives in exactly one place.
        return cls.get_instance(name).direction

    @classmethod
    def get_by_type(cls, *ml_types: MLType) -> list[str]:
        """Get list of metric names for specified ML types.

        Args:
            *ml_types: One or more MLType enums (e.g., MLType.REGRESSION, MLType.BINARY, MLType.MULTICLASS).

        Returns:
            List of metric names matching any of the specified ML types.

        Example:
            >>> Metrics.get_by_type(MLType.REGRESSION)
            ['RMSE', 'MAE', 'R2', ...]
            >>> Metrics.get_by_type(MLType.BINARY, MLType.MULTICLASS)
            ['AUCROC', 'ACC', ...]
        """
        requested = frozenset(ml_types)
        # Go through get_instance() so each Metric is built once and cached,
        # instead of re-running every factory on each call.
        matching_metrics = [name for name in cls._config_factories if cls.get_instance(name).ml_types & requested]
        return sorted(matching_metrics)

get_all_metrics() classmethod

Get all registered metric factory functions.

Returns:

Type Description
dict[str, Callable[[], Metric]]

Dictionary mapping metric names to their factory functions.

Source code in octopus/metrics/core.py
@classmethod
def get_all_metrics(cls) -> dict[str, Callable[[], Metric]]:
    """Get all registered metric factory functions.

    Returns:
        Dictionary mapping metric names to their factory functions.
        A shallow copy is returned so callers cannot mutate the internal
        registry; use register() to add metrics.
    """
    # Copy defensively: handing out the live registry would let callers
    # bypass register() and its duplicate-name check.
    return dict(cls._config_factories)

get_by_type(*ml_types) classmethod

Get list of metric names for specified ML types.

Parameters:

Name Type Description Default
*ml_types MLType

One or more MLType enums (e.g., MLType.REGRESSION, MLType.BINARY, MLType.MULTICLASS).

()

Returns:

Type Description
list[str]

List of metric names matching any of the specified ML types.

Example

>>> Metrics.get_by_type(MLType.REGRESSION)
['RMSE', 'MAE', 'R2', ...]
>>> Metrics.get_by_type(MLType.BINARY, MLType.MULTICLASS)
['AUCROC', 'ACC', ...]

Source code in octopus/metrics/core.py
@classmethod
def get_by_type(cls, *ml_types: MLType) -> list[str]:
    """List the registered metric names applicable to any of the given ML types.

    Args:
        *ml_types: One or more MLType enums (e.g., MLType.REGRESSION, MLType.BINARY, MLType.MULTICLASS).

    Returns:
        Sorted list of metric names whose supported types overlap the requested ones.

    Example:
        >>> Metrics.get_by_type(MLType.REGRESSION)
        ['RMSE', 'MAE', 'R2', ...]
        >>> Metrics.get_by_type(MLType.BINARY, MLType.MULTICLASS)
        ['AUCROC', 'ACC', ...]
    """
    wanted = frozenset(ml_types)
    # Build each registered config and keep those whose supported ML types
    # intersect the requested set.
    names = [name for name, factory in cls._config_factories.items() if factory().ml_types & wanted]
    return sorted(names)

get_direction(name) classmethod

Get the optuna direction by name.

Parameters:

Name Type Description Default
name str

The name of the metric.

required

Returns:

Type Description
MetricDirection

MetricDirection.MAXIMIZE if higher_is_better is True, else MetricDirection.MINIMIZE.

Source code in octopus/metrics/core.py
@classmethod
def get_direction(cls, name: str) -> MetricDirection:
    """Get the optuna direction by name.

    Args:
        name: The name of the metric.

    Returns:
        MetricDirection.MAXIMIZE if higher_is_better is True, else MetricDirection.MINIMIZE.
    """
    # Delegate to Metric.direction so the higher_is_better -> direction
    # mapping lives in exactly one place.
    return cls.get_instance(name).direction

get_instance(name) classmethod

Get metric instance by name.

This is the primary method for getting a metric to use for calculation. Returns a Metric instance that has calculate() and calculate_t2e() methods.

Parameters:

Name Type Description Default
name str

The name of the metric to retrieve.

required

Returns:

Type Description
Metric

Metric instance with calculate methods.

Raises:

Type Description
UnknownMetricError

If no metric with the specified name is found.

Usage

metric = Metrics.get_instance("AUCROC")
value = metric.calculate(y_true, y_pred)

Source code in octopus/metrics/core.py
@classmethod
def get_instance(cls, name: str) -> Metric:
    """Look up (building and caching on first use) the Metric registered under *name*.

    This is the primary entry point for obtaining a metric; the returned
    Metric exposes calculate() and calculate_t2e().

    Args:
        name: The name of the metric to retrieve.

    Returns:
        Metric instance with calculate methods.

    Raises:
        UnknownMetricError: If no metric with the specified name is found.

    Usage:
        metric = Metrics.get_instance("AUCROC")
        value = metric.calculate(y_true, y_pred)
    """
    cached = cls._metric_configs.get(name)
    if cached is not None:
        return cached

    factory = cls._config_factories.get(name)
    if factory is None:
        available = ", ".join(sorted(cls._config_factories))
        raise UnknownMetricError(
            f"Unknown metric '{name}'. Available metrics are: {available}. Please check the metric name and try again."
        )

    # Instantiate via the factory, force the instance's name to match the
    # registry key, and memoize for subsequent lookups.
    instance = factory()
    object.__setattr__(instance, "name", name)
    cls._metric_configs[name] = instance
    return instance

register(name) classmethod

Register a metric factory function under a given name.

Parameters:

Name Type Description Default
name str

The name to register the metric under.

required

Returns:

Type Description
Callable[[Callable[[], Metric]], Callable[[], Metric]]

Decorator function.

Source code in octopus/metrics/core.py
@classmethod
def register(cls, name: str) -> Callable[[Callable[[], Metric]], Callable[[], Metric]]:
    """Create a decorator that records a metric factory in the registry under *name*.

    Args:
        name: The name to register the metric under.

    Returns:
        Decorator function.
    """

    def _record(factory: Callable[[], Metric]) -> Callable[[], Metric]:
        # Refuse duplicates so one registration cannot silently shadow another.
        if name in cls._config_factories:
            raise ValueError(f"Metric '{name}' is already registered.")
        cls._config_factories[name] = factory
        return factory

    return _record

PredictionType

Bases: StrEnum

The format in which a metric expects its predictions.

  • PREDICTIONS: the metric receives hard class labels (0/1 for binary, integer class indices for multiclass, continuous values for regression). For binary classification, these are derived by thresholding predict_proba output — predict is not called directly.
  • PROBABILITIES: the metric receives raw probability scores or continuous outputs directly from predict_proba.

Used in:

  • Metric.prediction_type: declared per metric in the registry
  • metrics/utils.py: used to prepare the correct input before calling the metric function
  • EfsModule: selects the correct column from the CV predictions table

Source code in octopus/types.py
class PredictionType(StrEnum):
    """The format in which a metric expects its predictions.

    - ``PREDICTIONS``: the metric receives hard class labels (0/1 for binary,
      integer class indices for multiclass, continuous values for regression).
      For binary classification, these are derived by thresholding
      ``predict_proba`` output — ``predict`` is not called directly.
    - ``PROBABILITIES``: the metric receives raw probability scores or
      continuous outputs directly from ``predict_proba``.

    Used in:
    - ``Metric.prediction_type``: declared per metric in the registry
    - ``metrics/utils.py``: used to prepare the correct input before calling
      the metric function
    - ``EfsModule``: selects the correct column from the CV predictions table
    """

    PREDICTIONS = "predictions"  # hard labels / continuous point predictions
    PROBABILITIES = "probabilities"  # raw probability scores from predict_proba

acc_metric()

Accuracy metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("ACC")
def acc_metric() -> Metric:
    """Build the accuracy metric config (binary and multiclass classification)."""
    config = Metric(
        name="ACC",
        metric_function=accuracy_score,
        ml_types=[MLType.BINARY, MLType.MULTICLASS],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="accuracy",
    )
    return config

accbal_metric()

Balanced accuracy metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("ACCBAL")
def accbal_metric() -> Metric:
    """Build the balanced-accuracy metric config (binary and multiclass)."""
    config = Metric(
        name="ACCBAL",
        metric_function=balanced_accuracy_score,
        ml_types=[MLType.BINARY, MLType.MULTICLASS],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="balanced_accuracy",
    )
    return config

accbal_multiclass_metric()

Balanced accuracy metric configuration for multiclass problems.

Source code in octopus/metrics/multiclass.py
@Metrics.register("ACCBAL_MC")
def accbal_multiclass_metric() -> Metric:
    """Build the balanced-accuracy metric config restricted to multiclass tasks."""
    config = Metric(
        name="ACCBAL_MC",
        metric_function=balanced_accuracy_score,
        ml_types=[MLType.MULTICLASS],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="balanced_accuracy",
    )
    return config

aucpr_metric()

AUCPR metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("AUCPR")
def aucpr_metric() -> Metric:
    """Build the area-under-precision-recall-curve metric config (binary)."""
    config = Metric(
        name="AUCPR",
        metric_function=average_precision_score,
        ml_types=[MLType.BINARY],
        prediction_type=PredictionType.PROBABILITIES,
        higher_is_better=True,
        scorer_string="average_precision",
    )
    return config

aucroc_macro_multiclass_metric()

AUCROC metric configuration for multiclass problems (macro-average).

Source code in octopus/metrics/multiclass.py
@Metrics.register("AUCROC_MACRO")
def aucroc_macro_multiclass_metric() -> Metric:
    """Build the multiclass AUCROC config, one-vs-rest with macro averaging."""
    config = Metric(
        name="AUCROC_MACRO",
        metric_function=roc_auc_score,
        # One-vs-rest with an unweighted (macro) mean over classes.
        metric_params={"multi_class": "ovr", "average": "macro"},
        ml_types=[MLType.MULTICLASS],
        prediction_type=PredictionType.PROBABILITIES,
        higher_is_better=True,
        scorer_string="roc_auc_ovr",
    )
    return config

aucroc_metric()

AUCROC metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("AUCROC")
def aucroc_metric() -> Metric:
    """Build the area-under-ROC-curve metric config (binary)."""
    config = Metric(
        name="AUCROC",
        metric_function=roc_auc_score,
        ml_types=[MLType.BINARY],
        prediction_type=PredictionType.PROBABILITIES,
        higher_is_better=True,
        scorer_string="roc_auc",
    )
    return config

aucroc_weighted_multiclass_metric()

AUCROC metric configuration for multiclass problems (weighted-average).

Source code in octopus/metrics/multiclass.py
@Metrics.register("AUCROC_WEIGHTED")
def aucroc_weighted_multiclass_metric() -> Metric:
    """Build the multiclass AUCROC config, one-vs-rest with support-weighted averaging."""
    config = Metric(
        name="AUCROC_WEIGHTED",
        metric_function=roc_auc_score,
        # One-vs-rest; per-class scores weighted by class support.
        metric_params={"multi_class": "ovr", "average": "weighted"},
        ml_types=[MLType.MULTICLASS],
        prediction_type=PredictionType.PROBABILITIES,
        higher_is_better=True,
        scorer_string="roc_auc_ovr_weighted",
    )
    return config

cindex_metric()

Harrell's concordance index metric configuration.

Source code in octopus/metrics/timetoevent.py
@Metrics.register("CI")
def cindex_metric() -> Metric:
    """Build Harrell's concordance-index config (time-to-event)."""
    config = Metric(
        name="CI",
        metric_function=_harrell_concordance_index,
        ml_types=[MLType.TIMETOEVENT],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="concordance_index",
    )
    return config

cindex_uno_metric()

Uno's concordance index metric configuration.

Source code in octopus/metrics/timetoevent.py
@Metrics.register("CI_UNO")
def cindex_uno_metric() -> Metric:
    """Build Uno's concordance-index config (time-to-event)."""
    config = Metric(
        name="CI_UNO",
        metric_function=_uno_concordance_index,
        ml_types=[MLType.TIMETOEVENT],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="concordance_index_uno",
    )
    return config

f1_metric()

F1 metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("F1")
def f1_metric() -> Metric:
    """Build the F1-score metric config (binary)."""
    config = Metric(
        name="F1",
        metric_function=f1_score,
        ml_types=[MLType.BINARY],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="f1",
    )
    return config

logloss_metric()

Log loss metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("LOGLOSS")
def logloss_metric() -> Metric:
    """Log loss metric configuration.

    sklearn's ``log_loss`` returns the raw loss (smaller is better), so
    ``higher_is_better`` must be False — consistent with MAE/MSE/RMSE, which
    likewise pair a raw loss function with a ``neg_*`` scorer string.
    """
    return Metric(
        name="LOGLOSS",
        metric_function=log_loss,
        ml_types=[MLType.BINARY, MLType.MULTICLASS],
        higher_is_better=False,  # BUG FIX: was True, which would maximize the loss
        prediction_type=PredictionType.PROBABILITIES,
        scorer_string="neg_log_loss",
    )

mae_metric()

MAE metric configuration.

Source code in octopus/metrics/regression.py
@Metrics.register("MAE")
def mae_metric() -> Metric:
    """Build the mean-absolute-error metric config (regression, lower is better)."""
    config = Metric(
        name="MAE",
        metric_function=mean_absolute_error,
        ml_types=[MLType.REGRESSION],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=False,
        scorer_string="neg_mean_absolute_error",
    )
    return config

mcc_metric()

Matthews Correlation Coefficient metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("MCC")
def mcc_metric() -> Metric:
    """Build the Matthews-correlation-coefficient metric config (binary and multiclass)."""
    config = Metric(
        name="MCC",
        metric_function=matthews_corrcoef,
        ml_types=[MLType.BINARY, MLType.MULTICLASS],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="matthews_corrcoef",
    )
    return config

mse_metric()

MSE metric configuration.

Source code in octopus/metrics/regression.py
@Metrics.register("MSE")
def mse_metric() -> Metric:
    """Build the mean-squared-error metric config (regression, lower is better)."""
    config = Metric(
        name="MSE",
        metric_function=mean_squared_error,
        ml_types=[MLType.REGRESSION],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=False,
        scorer_string="neg_mean_squared_error",
    )
    return config

negbrierscore_metric()

Brier score metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("NEGBRIERSCORE")
def negbrierscore_metric() -> Metric:
    """Negated Brier score metric configuration.

    sklearn's ``brier_score_loss`` returns the raw loss (smaller is better).
    The raw loss combined with ``higher_is_better=True`` optimized in the
    wrong direction; negating the value makes it match the metric's name
    ("NEG...") and the ``neg_brier_score`` scorer string.
    """

    def _neg_brier_score(y_true, y_pred, **kwargs):
        # Larger (less negative) is better, consistent with higher_is_better=True.
        return -brier_score_loss(y_true, y_pred, **kwargs)

    return Metric(
        name="NEGBRIERSCORE",
        metric_function=_neg_brier_score,
        ml_types=[MLType.BINARY],
        higher_is_better=True,
        prediction_type=PredictionType.PROBABILITIES,
        scorer_string="neg_brier_score",
    )

precision_metric()

Precision metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("PRECISION")
def precision_metric() -> Metric:
    """Build the precision metric config (binary)."""
    config = Metric(
        name="PRECISION",
        metric_function=precision_score,
        ml_types=[MLType.BINARY],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="precision",
    )
    return config

r2_metric()

R2 metric configuration.

Source code in octopus/metrics/regression.py
@Metrics.register("R2")
def r2_metric() -> Metric:
    """Build the coefficient-of-determination (R^2) metric config (regression)."""
    config = Metric(
        name="R2",
        metric_function=r2_score,
        ml_types=[MLType.REGRESSION],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="r2",
    )
    return config

recall_metric()

Recall metric configuration.

Source code in octopus/metrics/classification.py
@Metrics.register("RECALL")
def recall_metric() -> Metric:
    """Build the recall metric config (binary)."""
    config = Metric(
        name="RECALL",
        metric_function=recall_score,
        ml_types=[MLType.BINARY],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=True,
        scorer_string="recall",
    )
    return config

rmse_metric()

RMSE metric configuration.

Source code in octopus/metrics/regression.py
@Metrics.register("RMSE")
def rmse_metric() -> Metric:
    """Build the root-mean-squared-error metric config (regression, lower is better)."""
    config = Metric(
        name="RMSE",
        metric_function=root_mean_squared_error,
        ml_types=[MLType.REGRESSION],
        prediction_type=PredictionType.PREDICTIONS,
        higher_is_better=False,
        scorer_string="neg_root_mean_squared_error",
    )
    return config

root_mean_squared_error(y_true, y_pred)

Calculate Root Mean Squared Error (RMSE).

Parameters:

Name Type Description Default
y_true ndarray

True target values

required
y_pred ndarray

Predicted target values

required

Returns:

Name Type Description
float float

RMSE value

Source code in octopus/metrics/regression.py
def root_mean_squared_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root Mean Squared Error: the square root of sklearn's mean squared error.

    Args:
        y_true: True target values
        y_pred: Predicted target values

    Returns:
        float: RMSE value
    """
    mean_sq = mean_squared_error(y_true, y_pred)
    return math.sqrt(mean_sq)