Skip to content

Commit f1bedd7

Browse files
authored
Fix scenario mining evaluation bug (#311)
* Fix scenario mining evaluation bug * Remove velocity from scenario mining evaluation
1 parent b44ea17 commit f1bedd7

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

  • src/av2/evaluation/scenario_mining

src/av2/evaluation/scenario_mining/eval.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
33
Evaluation Metrics:
44
HOTA: see https://arxiv.org/abs/2009.07736
5-
scenario-level F1: see https://jivp-eurasipjournals.springeropen.com/articles/10.1155/2008/246309
6-
timestamp-level F1: see https://arxiv.org/abs/2008.08063
5+
scenario-level F1
6+
timestamp-level F1
77
"""
88

99
from pathlib import Path
@@ -136,11 +136,13 @@ def filter_drivable_area(tracks: Sequences, dataset_dir: Optional[str]) -> Seque
136136
frame["translation_m"] = frame["translation_m"][is_evaluated]
137137
frame["size"] = frame["size"][is_evaluated]
138138
frame["yaw"] = frame["yaw"][is_evaluated]
139-
frame["velocity_m_per_s"] = frame["velocity_m_per_s"][is_evaluated]
140139
frame["label"] = frame["label"][is_evaluated]
141140
frame["name"] = frame["name"][is_evaluated]
142141
frame["track_id"] = frame["track_id"][is_evaluated]
143142

143+
if "velocity_m_per_s" in frame:
144+
frame["velocity_m_per_s"] = frame["velocity_m_per_s"][is_evaluated]
145+
144146
if "score" in frame:
145147
frame["score"] = frame["score"][is_evaluated]
146148

@@ -251,8 +253,8 @@ def compute_temporal_metrics(
251253
output_dir: The directory to save the plotted confusion matrices.
252254
253255
Returns:
254-
timestamp_f1: The F1 score where each timestamp counts as a prediction to evaluate.
255256
scenario_f1: The F1 score where each log-prompt pair counts as a prediction to evaluate.
257+
timestamp_f1: The F1 score where each timestamp counts as a prediction to evaluate.
256258
257259
258260
"""
@@ -364,7 +366,7 @@ def evaluate(
364366
output_dir = out + "/partial_tracks"
365367
Path(output_dir).mkdir(parents=True, exist_ok=True)
366368

367-
partial_track_hota, timestamp_f1, scenario_f1 = evaluate_scenario_mining(
369+
partial_track_hota, scenario_f1, timestamp_f1 = evaluate_scenario_mining(
368370
track_predictions,
369371
labels,
370372
objective_metric=objective_metric,
@@ -421,8 +423,8 @@ def evaluate_scenario_mining(
421423
422424
Returns:
423425
referred_hota: The HOTA tracking metric applied to all objects with the category REFERRED_OBJECT
424-
timestamp_f1: A retrieval/classification metric for determining if each timestamp contains any instance of the prompt.
425426
scenario_f1: A retrieval/classification metric for determining if each data log contains any instance of the prompt.
427+
timestamp_f1: A retrieval/classification metric for determining if each timestamp contains any instance of the prompt.
426428
"""
427429
classes = list(AV2_CATEGORIES)
428430

0 commit comments

Comments
 (0)