91 changes: 88 additions & 3 deletions asteroid/metrics.py
@@ -1,8 +1,8 @@
import json
import warnings
import traceback
from collections import Counter
from typing import List

import pandas as pd
import numpy as np
from pb_bss_eval import InputMetrics, OutputMetrics
@@ -30,7 +30,7 @@ def get_metrics(
clean (np.array): reference array.
estimate (np.array): estimate array.
sample_rate (int): sampling rate of the audio clips.
metrics_list (Union [str, list]): List of metrics to compute.
        metrics_list (Union[List[str], str]): List of metrics to compute.
Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
average (bool): Return dict([float]) if True, else dict([array]).
compute_permutation (bool): Whether to compute the permutation on
@@ -115,6 +115,91 @@ def get_metrics(
return utt_metrics
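
For reference, a minimal sketch of calling get_metrics as documented above, reusing the shapes and sample rate from the test file further down (random arrays stand in for real audio):

import numpy as np
from asteroid.metrics import get_metrics

mix = np.random.randn(1, 4000)
clean = np.random.randn(1, 4000)
estimate = np.random.randn(1, 4000)
# The returned dict has one entry per metric, plus "input_"-prefixed
# entries for the same metrics computed on the input mixture.
metrics_dict = get_metrics(
    mix, clean, estimate, sample_rate=8000, metrics_list=["si_sdr", "sdr"]
)
print(metrics_dict["si_sdr"], metrics_dict["input_si_sdr"])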


class MetricTracker:
"""Metric tracker, subject to change.

Args:
sample_rate (int): sampling rate of the audio clips.
        metrics_list (Union[List[str], str]): List of metrics to compute.
Defaults to 'all' (['si_sdr', 'sdr', 'sir', 'sar', 'stoi', 'pesq']).
average (bool): Return dict([float]) if True, else dict([array]).
compute_permutation (bool): Whether to compute the permutation on
estimate sources for the output metrics (default False)
ignore_metrics_errors (bool): Whether to ignore errors that occur in
computing the metrics. A warning will be printed instead.
"""

def __init__(
self,
sample_rate,
metrics_list=tuple(ALL_METRICS),
average=True,
compute_permutation=False,
ignore_metrics_errors=False,
):
self.sample_rate = sample_rate
# TODO: support WER in metrics_list when merged.
self.metrics_list = metrics_list
self.average = average
self.compute_permutation = compute_permutation
self.ignore_metrics_errors = ignore_metrics_errors

self.series_list = []
self._len_last_saved = 0
self._all_metrics = pd.DataFrame()

def __call__(
self, *, mix: np.ndarray, clean: np.ndarray, estimate: np.ndarray, filename=None, **kwargs
):
"""Compute metrics for mix/clean/estimate and log it to the class.

Args:
mix (np.array): mixture array.
clean (np.array): reference array.
estimate (np.array): estimate array.
filename (str, optional): If computing a metric fails, print this
filename along with the exception/warning message for debugging purposes.
            **kwargs: Any key-value pairs to log with the utterance metrics (filename, speaker ID, etc.).
"""
utt_metrics = get_metrics(
mix,
clean,
estimate,
sample_rate=self.sample_rate,
metrics_list=self.metrics_list,
average=self.average,
compute_permutation=self.compute_permutation,
ignore_metrics_errors=self.ignore_metrics_errors,
filename=filename,
)
utt_metrics.update(kwargs)
self.series_list.append(pd.Series(utt_metrics))

def as_df(self):
"""Return dataframe containing the results (cached)."""
if self._len_last_saved == len(self.series_list):
return self._all_metrics
self._len_last_saved = len(self.series_list)
self._all_metrics = pd.DataFrame(self.series_list)
        return self._all_metrics

def final_report(self, dump_path: str = None):
"""Return dict of average metrics. Dump to JSON if `dump_path` is not None."""
final_results = {}
metrics_df = self.as_df()
for metric_name in self.metrics_list:
input_metric_name = "input_" + metric_name
ldf = metrics_df[metric_name] - metrics_df[input_metric_name]
final_results[metric_name] = metrics_df[metric_name].mean()
final_results[metric_name + "_imp"] = ldf.mean()
if dump_path is not None:
dump_path = dump_path + ".json" if not dump_path.endswith(".json") else dump_path
with open(dump_path, "w") as f:
json.dump(final_results, f, indent=0)
return final_results


class MockWERTracker:
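    # No-op stand-in for the WER tracker, used when WER computation is disabled.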
def __init__(self, *args, **kwargs):
pass
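A usage sketch for the new MetricTracker, under the same assumptions as the test below (random arrays in place of real separation outputs; the mix_path kwarg is illustrative):

import numpy as np
from asteroid.metrics import MetricTracker

tracker = MetricTracker(sample_rate=8000, metrics_list=["si_sdr", "stoi"])
for i in range(10):
    mix = np.random.randn(1, 4000)
    clean = np.random.randn(1, 4000)
    estimate = np.random.randn(1, 4000)
    # Any extra kwargs (here, mix_path) are logged alongside the metrics.
    tracker(mix=mix, clean=clean, estimate=estimate, mix_path=f"ex_{i}.wav")

df = tracker.as_df()  # one row per utterance, including a "mix_path" column
report = tracker.final_report()  # averages plus "<metric>_imp" improvements
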
18 changes: 17 additions & 1 deletion tests/metrics_test.py
@@ -1,7 +1,7 @@
from unittest import mock
import numpy as np
import pytest
from asteroid.metrics import get_metrics
from asteroid.metrics import get_metrics, MetricTracker


@pytest.mark.parametrize("fs", [8000, 16000])
@@ -69,3 +69,19 @@ def test_ignore_errors(filename, average):
)
assert metrics_dict["si_sdr"] is None
assert metrics_dict["pesq"] is not None


def test_metric_tracker():
metric_tracker = MetricTracker(sample_rate=8000, metrics_list=["si_sdr", "stoi"])
for i in range(5):
mix = np.random.randn(1, 4000)
clean = np.random.randn(1, 4000)
est = np.random.randn(1, 4000)
metric_tracker(mix=mix, clean=clean, estimate=est, mix_path=f"path{i}")

# Test dump & final report
metric_tracker.final_report()
metric_tracker.final_report(dump_path="final_metrics.json")

# Check that kwargs are passed.
assert "mix_path" in metric_tracker.as_df()