Add Brier metrics to motion forecasting evaluation module #44
Merged
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ def compute_ade(forecasted_trajectories: NDArrayNumber, gt_trajectory: NDArrayNu | |
| Returns: | ||
| (K,) Average displacement error for each of the predicted trajectories. | ||
| """ | ||
| # (K,N) | ||
| displacement_errors = np.linalg.norm(forecasted_trajectories - gt_trajectory, axis=2) # type: ignore | ||
| ade: NDArrayFloat = np.mean(displacement_errors, axis=1) | ||
| return ade | ||
|
|
@@ -56,3 +55,89 @@ def compute_is_missed_prediction( | |
| fde = compute_fde(forecasted_trajectories, gt_trajectory) | ||
| is_missed_prediction = fde > miss_threshold_m # type: ignore | ||
| return is_missed_prediction | ||
|
|
||
|
|
||
| def compute_brier_ade( | ||
| forecasted_trajectories: NDArrayNumber, | ||
| gt_trajectory: NDArrayNumber, | ||
| forecast_probabilities: NDArrayNumber, | ||
| normalize: bool = False, | ||
| ) -> NDArrayFloat: | ||
| """Compute a probability-weighted (using Brier score) ADE for K predicted trajectories (for the same actor). | ||
|
|
||
| Args: | ||
| forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length. | ||
| gt_trajectory: (N, 2) ground truth trajectory. | ||
| forecast_probabilities: (K,) probabilities associated with each prediction. | ||
| normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True. | ||
|
|
||
| Returns: | ||
| (K,) Probability-weighted average displacement error for each predicted trajectory. | ||
| """ | ||
| # Compute ADE with Brier score component | ||
| brier_score = _compute_brier_score(forecasted_trajectories, forecast_probabilities, normalize) | ||
| ade_vector = compute_ade(forecasted_trajectories, gt_trajectory) | ||
| brier_ade: NDArrayFloat = ade_vector + brier_score | ||
| return brier_ade | ||
|
|
||
|
|
||
| def compute_brier_fde( | ||
| forecasted_trajectories: NDArrayNumber, | ||
| gt_trajectory: NDArrayNumber, | ||
| forecast_probabilities: NDArrayNumber, | ||
| normalize: bool = False, | ||
| ) -> NDArrayFloat: | ||
| """Compute a probability-weighted (using Brier score) FDE for K predicted trajectories (for the same actor). | ||
|
|
||
| Args: | ||
| forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length. | ||
| gt_trajectory: (N, 2) ground truth trajectory. | ||
| forecast_probabilities: (K,) probabilities associated with each prediction. | ||
| normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True. | ||
|
|
||
| Returns: | ||
| (K,) Probability-weighted final displacement error for each predicted trajectory. | ||
| """ | ||
| # Compute FDE with Brier score component | ||
| brier_score = _compute_brier_score(forecasted_trajectories, forecast_probabilities, normalize) | ||
| fde_vector = compute_fde(forecasted_trajectories, gt_trajectory) | ||
| brier_fde: NDArrayFloat = fde_vector + brier_score | ||
| return brier_fde | ||
|
|
||
|
|
||
| def _compute_brier_score( | ||
| forecasted_trajectories: NDArrayNumber, | ||
| forecast_probabilities: NDArrayNumber, | ||
| normalize: bool = False, | ||
| ) -> NDArrayFloat: | ||
| """Compute Brier score for K predicted trajectories. | ||
|
|
||
| Note: This function computes Brier score under the assumption that each trajectory is the true "best" prediction | ||
| (i.e. each predicted trajectory has a ground truth probability of 1.0). | ||
|
|
||
| Args: | ||
| forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length. | ||
| forecast_probabilities: (K,) probabilities associated with each prediction. | ||
| normalize: Normalizes `forecast_probabilities` to sum to 1 when set to True. | ||
|
|
||
| Raises: | ||
| ValueError: If the number of forecasted trajectories and probabilities don't match. | ||
| ValueError: If `forecast_probabilities` contains values outside of the range [0, 1] (checked before any normalization). | ||
|
|
||
| Returns: | ||
| (K,) Brier score for each predicted trajectory. | ||
| """ | ||
| # Validate that # of forecast probabilities matches forecasted trajectories | ||
| if len(forecasted_trajectories) != len(forecast_probabilities): | ||
| raise ValueError("Number of forecasted trajectories and forecast probabilities must match.") | ||
|
|
||
| # Validate that all forecast probabilities are in the range [0, 1] | ||
| if np.logical_or(forecast_probabilities < 0.0, forecast_probabilities > 1.0).any(): | ||
| raise ValueError("At least one forecast probability falls outside the range [0, 1].") | ||
|
|
||
|
The two functions differ by just 1 line. Maybe move most of the stuff to a common function.
||
| # If enabled, normalize forecast probabilities to sum to 1 | ||
| if normalize: | ||
| forecast_probabilities = forecast_probabilities / np.sum(forecast_probabilities) | ||
|
|
||
| brier_score: NDArrayFloat = np.square((1 - forecast_probabilities)) | ||
| return brier_score | ||
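As a worked illustration of the metric added in this file, the sketch below reimplements the Brier-ADE computation in a self-contained form. The function name `brier_ade_sketch` and the toy inputs are assumptions made for illustration and are not part of the PR. Like `_compute_brier_score`, it treats every trajectory as having a ground-truth probability of 1.0, so the added term is `(1 - p)^2`.

```python
import numpy as np


def brier_ade_sketch(forecasts, gt, probs, normalize=False):
    """Hypothetical standalone version of ADE + (1 - p)^2 for K forecasts."""
    forecasts = np.asarray(forecasts, dtype=float)  # (K, N, 2)
    gt = np.asarray(gt, dtype=float)  # (N, 2)
    probs = np.asarray(probs, dtype=float)  # (K,)

    # Same validation order as the diff: count first, then range.
    if len(forecasts) != len(probs):
        raise ValueError("Number of forecasts and probabilities must match.")
    if np.logical_or(probs < 0.0, probs > 1.0).any():
        raise ValueError("At least one forecast probability falls outside the range [0, 1].")

    if normalize:
        probs = probs / np.sum(probs)

    # (K, N) per-timestep displacement, averaged over the N timesteps -> (K,)
    ade = np.mean(np.linalg.norm(forecasts - gt, axis=2), axis=1)
    return ade + np.square(1.0 - probs)


# Toy inputs (assumed): two forecasts over three timesteps, stationary ground truth.
gt = np.zeros((3, 2))
forecasts = np.stack([np.zeros((3, 2)), np.ones((3, 2))])
result = brier_ade_sketch(forecasts, gt, np.array([0.8, 0.2]))
```

With these inputs the first forecast matches the ground truth exactly, so its score is only the Brier term `(1 - 0.8)^2`; the second forecast is offset by `sqrt(2)` at every timestep and also pays the larger penalty `(1 - 0.2)^2`.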
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,8 @@ | |
|
|
||
| """Unit tests for motion forecasting metrics.""" | ||
|
|
||
| from contextlib import AbstractContextManager | ||
| from contextlib import nullcontext as does_not_raise | ||
| from typing import Final | ||
|
|
||
| import numpy as np | ||
|
|
@@ -114,3 +116,147 @@ def test_compute_is_missed_prediction( | |
| # Check that is_missed labels are of the correct shape and have the correct value | ||
| assert is_missed_prediction.shape == forecasted_trajectories.shape[:1] | ||
| assert np.all(is_missed_prediction == expected_is_missed_label) | ||
|
|
||
|
|
||
| uniform_probabilities_k1: NDArrayFloat = np.ones((1,)) | ||
| uniform_probabilities_k6: NDArrayFloat = np.ones((6,)) / 6 | ||
| confident_probabilities_k6: NDArrayFloat = np.array([0.9, 0.02, 0.02, 0.02, 0.02, 0.02]) | ||
| non_normalized_probabilities_k6: NDArrayFloat = confident_probabilities_k6 / 10 | ||
| out_of_range_probabilities_k6: NDArrayFloat = confident_probabilities_k6 * 10 | ||
| wrong_shape_probabilities_k6: NDArrayFloat = np.ones((5,)) / 5 | ||
|
[nit] One unit test for out-of-range probs.
||
|
|
||
| expected_bade_uniform_k1 = expected_ade_stationary_k1 | ||
| expected_bade_uniform_k6 = expected_ade_straight_k6 + np.square((1 - uniform_probabilities_k6)) | ||
| expected_bade_confident_k6 = expected_ade_straight_k6 + np.square((1 - confident_probabilities_k6)) | ||
|
|
||
| expected_bfde_uniform_k1 = expected_fde_stationary_k1 | ||
| expected_bfde_uniform_k6 = expected_fde_straight_k6 + np.square((1 - uniform_probabilities_k6)) | ||
| expected_bfde_confident_k6 = expected_fde_straight_k6 + np.square((1 - confident_probabilities_k6)) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "forecasted_trajectories, forecast_probabilities, normalize, expected_brier_ade", | ||
| [ | ||
| (forecasted_trajectories_stationary_k1, uniform_probabilities_k1, False, expected_bade_uniform_k1), | ||
| (forecasted_trajectories_straight_k6, uniform_probabilities_k6, False, expected_bade_uniform_k6), | ||
| (forecasted_trajectories_straight_k6, confident_probabilities_k6, False, expected_bade_confident_k6), | ||
| (forecasted_trajectories_straight_k6, non_normalized_probabilities_k6, True, expected_bade_confident_k6), | ||
| ], | ||
| ids=[ | ||
| "uniform_stationary_k1", | ||
| "uniform_k6", | ||
| "confident_k6", | ||
| "normalize_probabilities_k6", | ||
| ], | ||
| ) | ||
| def test_compute_brier_ade( | ||
| forecasted_trajectories: NDArrayNumber, | ||
| forecast_probabilities: NDArrayNumber, | ||
| normalize: bool, | ||
| expected_brier_ade: NDArrayFloat, | ||
| ) -> None: | ||
| """Test that compute_brier_ade returns the correct output with valid inputs. | ||
|
|
||
| Args: | ||
| forecasted_trajectories: Forecasted trajectories for test case. | ||
| forecast_probabilities: Forecast probabilities for test case. | ||
| normalize: Controls whether forecast probabilities should be normalized for test case. | ||
| expected_brier_ade: Expected probability-weighted ADE for test case. | ||
| """ | ||
| brier_ade = metrics.compute_brier_ade( | ||
| forecasted_trajectories, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize | ||
| ) | ||
| assert np.allclose(brier_ade, expected_brier_ade) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "forecast_probabilities, normalize, expectation", | ||
| [ | ||
| (non_normalized_probabilities_k6, True, does_not_raise()), | ||
| (out_of_range_probabilities_k6, False, pytest.raises(ValueError)), | ||
| (wrong_shape_probabilities_k6, True, pytest.raises(ValueError)), | ||
| ], | ||
| ids=[ | ||
| "valid_probabilities", | ||
| "out_of_range_probabilities", | ||
| "wrong_shape_probabilities", | ||
| ], | ||
| ) | ||
| def test_compute_brier_ade_data_validation( | ||
| forecast_probabilities: NDArrayNumber, normalize: bool, expectation: AbstractContextManager # type: ignore | ||
| ) -> None: | ||
| """Test that compute_brier_ade raises the correct errors when inputs are invalid. | ||
|
|
||
| Args: | ||
| forecast_probabilities: Forecast probabilities for test case. | ||
| normalize: Controls whether forecast probabilities should be normalized for test case. | ||
| expectation: Context manager to capture the expected exception for each test case. | ||
| """ | ||
| with expectation: | ||
| metrics.compute_brier_ade( | ||
| forecasted_trajectories_straight_k6, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "forecasted_trajectories, forecast_probabilities, normalize, expected_brier_fde", | ||
| [ | ||
| (forecasted_trajectories_stationary_k1, uniform_probabilities_k1, False, expected_bfde_uniform_k1), | ||
| (forecasted_trajectories_straight_k6, uniform_probabilities_k6, False, expected_bfde_uniform_k6), | ||
| (forecasted_trajectories_straight_k6, confident_probabilities_k6, False, expected_bfde_confident_k6), | ||
| (forecasted_trajectories_straight_k6, non_normalized_probabilities_k6, True, expected_bfde_confident_k6), | ||
| ], | ||
| ids=[ | ||
| "uniform_stationary_k1", | ||
| "uniform_k6", | ||
| "confident_k6", | ||
| "normalize_probabilities_k6", | ||
| ], | ||
| ) | ||
| def test_compute_brier_fde( | ||
| forecasted_trajectories: NDArrayNumber, | ||
| forecast_probabilities: NDArrayNumber, | ||
| normalize: bool, | ||
| expected_brier_fde: NDArrayFloat, | ||
| ) -> None: | ||
| """Test that compute_brier_fde returns the correct output with valid inputs. | ||
|
|
||
| Args: | ||
| forecasted_trajectories: Forecasted trajectories for test case. | ||
| forecast_probabilities: Forecast probabilities for test case. | ||
| normalize: Controls whether forecast probabilities should be normalized for test case. | ||
| expected_brier_fde: Expected probability-weighted FDE for test case. | ||
| """ | ||
| brier_fde = metrics.compute_brier_fde( | ||
| forecasted_trajectories, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize | ||
| ) | ||
| assert np.allclose(brier_fde, expected_brier_fde) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "forecast_probabilities, normalize, expectation", | ||
| [ | ||
| (non_normalized_probabilities_k6, True, does_not_raise()), | ||
| (out_of_range_probabilities_k6, False, pytest.raises(ValueError)), | ||
| (wrong_shape_probabilities_k6, True, pytest.raises(ValueError)), | ||
| ], | ||
| ids=[ | ||
| "valid_probabilities", | ||
| "out_of_range_probabilities", | ||
| "wrong_shape_probabilities", | ||
| ], | ||
| ) | ||
| def test_compute_brier_fde_data_validation( | ||
| forecast_probabilities: NDArrayNumber, normalize: bool, expectation: AbstractContextManager # type: ignore | ||
| ) -> None: | ||
| """Test that compute_brier_fde raises the correct errors when inputs are invalid. | ||
|
|
||
| Args: | ||
| forecast_probabilities: Forecast probabilities for test case. | ||
| normalize: Controls whether forecast probabilities should be normalized for test case. | ||
| expectation: Context manager to capture the expected exception for each test case. | ||
| """ | ||
| with expectation: | ||
| metrics.compute_brier_fde( | ||
| forecasted_trajectories_straight_k6, _STATIONARY_GT_TRAJ, forecast_probabilities, normalize | ||
| ) | ||
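The `normalize_probabilities_k6` cases above reuse the expected values of the `confident_k6` cases. A short check of the arithmetic shows why: dividing the confident probabilities by 10 leaves them summing to 0.1, and renormalizing to sum to 1 recovers the original values. The variable names below mirror the test fixtures but are redefined here so the snippet runs on its own.

```python
import numpy as np

# Same values as the fixtures in the test file, redefined for a standalone check.
confident = np.array([0.9, 0.02, 0.02, 0.02, 0.02, 0.02])
non_normalized = confident / 10  # sums to 0.1 rather than 1.0

# What normalize=True does before the Brier term is computed.
renormalized = non_normalized / np.sum(non_normalized)
brier_term = np.square(1 - renormalized)
```

Because `renormalized` equals `confident` up to floating-point error, the Brier term, and therefore the expected Brier-ADE and Brier-FDE, is the same in both cases.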
I see that all the metric functions here work on a single sample. For a batch, we might have to call these functions for individual samples. Wouldn't that be slower because no batch computation will be used?
That's a good point - we can certainly convert all these metric functions into batched equivalents in a follow-up PR. :-)
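For reference, a batched variant along the lines discussed here might look like the sketch below. This is not code from the PR or a follow-up; the function name `compute_brier_ade_batched` and the `(B, K, N, 2)` input layout are assumptions chosen for illustration. It relies only on NumPy broadcasting to evaluate all B samples in one call.

```python
import numpy as np


def compute_brier_ade_batched(forecasts, gt, probs):
    """Hypothetical batched Brier-ADE.

    forecasts: (B, K, N, 2), gt: (B, N, 2), probs: (B, K). Returns (B, K).
    """
    forecasts = np.asarray(forecasts, dtype=float)
    gt = np.asarray(gt, dtype=float)
    probs = np.asarray(probs, dtype=float)

    if forecasts.shape[:2] != probs.shape:
        raise ValueError("Expected probabilities of shape (B, K) matching the forecasts.")
    if np.logical_or(probs < 0.0, probs > 1.0).any():
        raise ValueError("At least one forecast probability falls outside the range [0, 1].")

    # Insert a K axis into the ground truth so it broadcasts: (B, 1, N, 2).
    displacement = np.linalg.norm(forecasts - gt[:, None], axis=-1)  # (B, K, N)
    ade = displacement.mean(axis=-1)  # (B, K)
    return ade + np.square(1.0 - probs)
```

Each row of the output should match what the single-sample computation returns for that sample, so an existing per-sample function can serve as a reference when testing a batched one.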