diff --git a/src/av2/datasets/motion_forecasting/eval/metrics.py b/src/av2/datasets/motion_forecasting/eval/metrics.py index 85be281a..4c923643 100644 --- a/src/av2/datasets/motion_forecasting/eval/metrics.py +++ b/src/av2/datasets/motion_forecasting/eval/metrics.py @@ -223,7 +223,7 @@ def compute_world_brier_fde( Args: forecasted_world_trajectories: (M, K, N, 2) K predicted trajectories of length N, for each of M actors. gt_world_trajectories: (M, N, 2) ground truth trajectories of length N, for each of M actors. - forecasted_world_probabilities: (M,) normalized probabilities associated with each world. + forecasted_world_probabilities: (K,) normalized probabilities associated with each world. normalize: Normalizes `forecasted_world_probabilities` to sum to 1 when set to True. Returns: