diff --git a/src/av2/datasets/motion_forecasting/eval/metrics.py b/src/av2/datasets/motion_forecasting/eval/metrics.py
index 791c786f..961ecf75 100644
--- a/src/av2/datasets/motion_forecasting/eval/metrics.py
+++ b/src/av2/datasets/motion_forecasting/eval/metrics.py
@@ -16,7 +16,6 @@ def compute_ade(forecasted_trajectories: NDArrayNumber, gt_trajectory: NDArrayNu
     Returns:
         (K,) Average displacement error for each of the predicted trajectories.
     """
-    # (K,N)
     displacement_errors = np.linalg.norm(forecasted_trajectories - gt_trajectory, axis=2)  # type: ignore
     ade: NDArrayFloat = np.mean(displacement_errors, axis=1)
     return ade
@@ -56,3 +55,106 @@ def compute_is_missed_prediction(
     fde = compute_fde(forecasted_trajectories, gt_trajectory)
     is_missed_prediction = fde > miss_threshold_m  # type: ignore
     return is_missed_prediction
+
+
+def compute_brier_ade(
+    forecasted_trajectories: NDArrayNumber,
+    gt_trajectory: NDArrayNumber,
+    forecast_probabilities: NDArrayNumber,
+    normalize: bool = False,
+) -> NDArrayFloat:
+    """Compute a probability-weighted (using Brier score) ADE for K predicted trajectories (for the same actor).
+
+    Args:
+        forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length.
+        gt_trajectory: (N, 2) ground truth trajectory.
+        forecast_probabilities: (K,) probabilities associated with each prediction.
+        normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True.
+
+    Raises:
+        ValueError: If `forecast_probabilities` does not contain K entries or holds invalid probability values.
+
+    Returns:
+        (K,) Probability-weighted average displacement error for each predicted trajectory.
+    """
+    # Compute ADE with Brier score component
+    brier_score = _compute_brier_score(forecasted_trajectories, forecast_probabilities, normalize)
+    ade_vector = compute_ade(forecasted_trajectories, gt_trajectory)
+    brier_ade: NDArrayFloat = ade_vector + brier_score
+    return brier_ade
+
+
+def compute_brier_fde(
+    forecasted_trajectories: NDArrayNumber,
+    gt_trajectory: NDArrayNumber,
+    forecast_probabilities: NDArrayNumber,
+    normalize: bool = False,
+) -> NDArrayFloat:
+    """Compute a probability-weighted (using Brier score) FDE for K predicted trajectories (for the same actor).
+
+    Args:
+        forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length.
+        gt_trajectory: (N, 2) ground truth trajectory.
+        forecast_probabilities: (K,) probabilities associated with each prediction.
+        normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True.
+
+    Raises:
+        ValueError: If `forecast_probabilities` does not contain K entries or holds invalid probability values.
+
+    Returns:
+        (K,) Probability-weighted final displacement error for each predicted trajectory.
+    """
+    # Compute FDE with Brier score component
+    brier_score = _compute_brier_score(forecasted_trajectories, forecast_probabilities, normalize)
+    fde_vector = compute_fde(forecasted_trajectories, gt_trajectory)
+    brier_fde: NDArrayFloat = fde_vector + brier_score
+    return brier_fde
+
+
+def _compute_brier_score(
+    forecasted_trajectories: NDArrayNumber,
+    forecast_probabilities: NDArrayNumber,
+    normalize: bool = False,
+) -> NDArrayFloat:
+    """Compute Brier score for K predicted trajectories.
+
+    Note: This function computes Brier score under the assumption that each trajectory is the true "best" prediction
+    (i.e. each predicted trajectory has a ground truth probability of 1.0).
+
+    Args:
+        forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length.
+        forecast_probabilities: (K,) probabilities associated with each prediction.
+        normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True.
+
+    Raises:
+        ValueError: If the number of forecasted trajectories and probabilities don't match.
+        ValueError: If normalize=False and `forecast_probabilities` contains values outside of the range [0, 1].
+        ValueError: If normalize=True and `forecast_probabilities` contains negative values, or is non-empty
+            but does not have a positive sum.
+
+    Returns:
+        (K,) Brier score for each predicted trajectory.
+    """
+    # Validate that # of forecast probabilities matches forecasted trajectories
+    if len(forecasted_trajectories) != len(forecast_probabilities):
+        raise ValueError(
+            f"Got {len(forecasted_trajectories)} forecasted trajectories but "
+            f"{len(forecast_probabilities)} forecast probabilities."
+        )
+
+    if normalize:
+        # Only non-negativity is required up front: dividing by a positive sum maps the values into [0, 1]
+        if (forecast_probabilities < 0.0).any():
+            raise ValueError("At least one forecast probability is negative.")
+
+        # An empty input has nothing to rescale; otherwise a zero (or NaN) sum cannot be normalized
+        probability_sum = np.sum(forecast_probabilities)
+        if len(forecast_probabilities) > 0 and not probability_sum > 0.0:
+            raise ValueError("Forecast probabilities must have a positive sum to be normalized.")
+        forecast_probabilities = forecast_probabilities / probability_sum
+    elif np.logical_or(forecast_probabilities < 0.0, forecast_probabilities > 1.0).any():
+        # Without normalization the values are used as-is, so they must already be valid probabilities
+        raise ValueError("At least one forecast probability falls outside the range [0, 1].")
+
+    brier_score: NDArrayFloat = np.square((1 - forecast_probabilities))
+    return brier_score
diff --git a/tests/datasets/motion_forecasting/eval/test_metrics.py b/tests/datasets/motion_forecasting/eval/test_metrics.py
index 4a4f12e1..4ffd4afa 100644
--- a/tests/datasets/motion_forecasting/eval/test_metrics.py
+++ b/tests/datasets/motion_forecasting/eval/test_metrics.py
@@ -2,6 +2,8 @@
 
 """Unit tests for motion forecasting metrics."""
 
+from contextlib import AbstractContextManager
+from contextlib import nullcontext as does_not_raise
 from typing import Final
 
 import numpy as np
@@ -114,3 +116,147 @@ def test_compute_is_missed_prediction(
     # Check that is_missed labels are of the correct shape and have the correct value
     assert is_missed_prediction.shape == forecasted_trajectories.shape[:1]
     assert np.all(is_missed_prediction == expected_is_missed_label)
+
+
+uniform_probabilities_k1: NDArrayFloat = np.ones((1,))
+uniform_probabilities_k6: NDArrayFloat = np.ones((6,)) / 6
+confident_probabilities_k6: NDArrayFloat = np.array([0.9, 0.02, 0.02, 0.02, 0.02, 0.02])
+non_normalized_probabilities_k6: NDArrayFloat = confident_probabilities_k6 / 10
+out_of_range_probabilities_k6: NDArrayFloat = confident_probabilities_k6 * 10
+wrong_shape_probabilities_k6: NDArrayFloat = np.ones((5,)) / 5
+
+expected_bade_uniform_k1 = expected_ade_stationary_k1
+expected_bade_uniform_k6 = expected_ade_straight_k6 + np.square((1 - uniform_probabilities_k6))
+expected_bade_confident_k6 = expected_ade_straight_k6 + np.square((1 - confident_probabilities_k6))
+
+expected_bfde_uniform_k1 = expected_fde_stationary_k1
+expected_bfde_uniform_k6 = expected_fde_straight_k6 + np.square((1 - uniform_probabilities_k6))
+expected_bfde_confident_k6 = expected_fde_straight_k6 + np.square((1 - confident_probabilities_k6))
+
+
+@pytest.mark.parametrize(
+    "forecasted_trajectories, forecast_probabilities, normalize, expected_brier_ade",
+    [
+        (forecasted_trajectories_stationary_k1, uniform_probabilities_k1, False, expected_bade_uniform_k1),
+        (forecasted_trajectories_straight_k6, uniform_probabilities_k6, False, expected_bade_uniform_k6),
+        (forecasted_trajectories_straight_k6, confident_probabilities_k6, False, expected_bade_confident_k6),
+        (forecasted_trajectories_straight_k6, non_normalized_probabilities_k6, True, expected_bade_confident_k6),
+    ],
+    ids=[
+        "uniform_stationary_k1",
+        "uniform_k6",
+        "confident_k6",
+        "normalize_probabilities_k6",
+    ],
+)
+def test_compute_brier_ade(
+    forecasted_trajectories: NDArrayNumber,
+    forecast_probabilities: NDArrayNumber,
+    normalize: bool,
+    expected_brier_ade: NDArrayFloat,
+) -> None:
+    """Test that compute_brier_ade returns the correct output with valid inputs.
+
+    Args:
+        forecasted_trajectories: Forecasted trajectories for test case.
+        forecast_probabilities: Forecast probabilities for test case.
+        normalize: Controls whether forecast probabilities should be normalized for test case.
+        expected_brier_ade: Expected probability-weighted ADE for test case.
+    """
+    brier_ade = metrics.compute_brier_ade(
+        forecasted_trajectories, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize
+    )
+    assert np.allclose(brier_ade, expected_brier_ade)
+
+
+@pytest.mark.parametrize(
+    "forecast_probabilities, normalize, expectation",
+    [
+        (non_normalized_probabilities_k6, True, does_not_raise()),
+        (out_of_range_probabilities_k6, False, pytest.raises(ValueError)),
+        (wrong_shape_probabilities_k6, True, pytest.raises(ValueError)),
+    ],
+    ids=[
+        "valid_probabilities",
+        "out_of_range_probabilities",
+        "wrong_shape_probabilities",
+    ],
+)
+def test_compute_brier_ade_data_validation(
+    forecast_probabilities: NDArrayNumber, normalize: bool, expectation: AbstractContextManager  # type: ignore
+) -> None:
+    """Test that compute_brier_ade raises the correct errors when inputs are invalid.
+
+    Args:
+        forecast_probabilities: Forecast probabilities for test case.
+        normalize: Controls whether forecast probabilities should be normalized for test case.
+        expectation: Context manager to capture the expected exception for each test case.
+    """
+    with expectation:
+        metrics.compute_brier_ade(
+            forecasted_trajectories_straight_k6, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize
+        )
+
+
+@pytest.mark.parametrize(
+    "forecasted_trajectories, forecast_probabilities, normalize, expected_brier_fde",
+    [
+        (forecasted_trajectories_stationary_k1, uniform_probabilities_k1, False, expected_bfde_uniform_k1),
+        (forecasted_trajectories_straight_k6, uniform_probabilities_k6, False, expected_bfde_uniform_k6),
+        (forecasted_trajectories_straight_k6, confident_probabilities_k6, False, expected_bfde_confident_k6),
+        (forecasted_trajectories_straight_k6, non_normalized_probabilities_k6, True, expected_bfde_confident_k6),
+    ],
+    ids=[
+        "uniform_stationary_k1",
+        "uniform_k6",
+        "confident_k6",
+        "normalize_probabilities_k6",
+    ],
+)
+def test_compute_brier_fde(
+    forecasted_trajectories: NDArrayNumber,
+    forecast_probabilities: NDArrayNumber,
+    normalize: bool,
+    expected_brier_fde: NDArrayFloat,
+) -> None:
+    """Test that compute_brier_fde returns the correct output with valid inputs.
+
+    Args:
+        forecasted_trajectories: Forecasted trajectories for test case.
+        forecast_probabilities: Forecast probabilities for test case.
+        normalize: Controls whether forecast probabilities should be normalized for test case.
+        expected_brier_fde: Expected probability-weighted FDE for test case.
+    """
+    brier_fde = metrics.compute_brier_fde(
+        forecasted_trajectories, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize
+    )
+    assert np.allclose(brier_fde, expected_brier_fde)
+
+
+@pytest.mark.parametrize(
+    "forecast_probabilities, normalize, expectation",
+    [
+        (non_normalized_probabilities_k6, True, does_not_raise()),
+        (out_of_range_probabilities_k6, False, pytest.raises(ValueError)),
+        (wrong_shape_probabilities_k6, True, pytest.raises(ValueError)),
+    ],
+    ids=[
+        "valid_probabilities",
+        "out_of_range_probabilities",
+        "wrong_shape_probabilities",
+    ],
+)
+def test_compute_brier_fde_data_validation(
+    forecast_probabilities: NDArrayNumber, normalize: bool, expectation: AbstractContextManager  # type: ignore
+) -> None:
+    """Test that compute_brier_fde raises the correct errors when inputs are invalid.
+
+    Args:
+        forecast_probabilities: Forecast probabilities for test case.
+        normalize: Controls whether forecast probabilities should be normalized for test case.
+        expectation: Context manager to capture the expected exception for each test case.
+    """
+    with expectation:
+        metrics.compute_brier_fde(
+            forecasted_trajectories_straight_k6, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize
+        )