Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Not released]

### Added
- Ability to get predictions without postprocessing in the exploration space. (Only back end for
now)
- New Smart Tags `pipeline_disagreement` and `incorrect_for_all_pipelines` as a first step for pipeline comparison.
- Links on top words to filter utterances that contain them.

Expand Down
3 changes: 2 additions & 1 deletion azimuth/dataset_split_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ def save_csv(self, table_key=None) -> str:
DatasetColumn.postprocessed_prediction,
DatasetColumn.model_confidences,
DatasetColumn.postprocessed_confidences,
DatasetColumn.model_outcome,
DatasetColumn.postprocessed_outcome,
DatasetColumn.confidence_bin_idx,
DatasetColumn.outcome,
DatasetColumn.token_count,
DatasetColumn.neighbors_train,
DatasetColumn.neighbors_eval,
Expand Down
10 changes: 5 additions & 5 deletions azimuth/modules/base_classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from azimuth.modules.base_classes.dask_module import ConfigScope, DaskModule
from azimuth.modules.base_classes.module import Module
from azimuth.modules.base_classes.expirable_mixin import ExpirableMixin
from azimuth.modules.base_classes.aggregation_module import (
AggregationModule,
ComparisonModule,
FilterableModule,
)
from azimuth.modules.base_classes.indexable_module import (
DatasetResultModule,
IndexableModule,
ModelContractModule,
)
from azimuth.modules.base_classes.aggregation_module import (
AggregationModule,
ComparisonModule,
FilterableModule,
)
46 changes: 43 additions & 3 deletions azimuth/modules/base_classes/aggregation_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
# in the root directory of this source tree.
import time
from abc import ABC
from typing import List, Optional
from typing import List, Optional, cast

from datasets import Dataset

from azimuth.modules.base_classes import ConfigScope, ExpirableMixin, Module
from azimuth.types import DatasetSplitName, ModuleOptions, ModuleResponse
from azimuth.types import DatasetColumn, DatasetSplitName, ModuleOptions, ModuleResponse
from azimuth.types.outcomes import OutcomeName
from azimuth.utils.filtering import filter_dataset_split


Expand All @@ -33,7 +34,7 @@ class ComparisonModule(AggregationModule[ConfigScope], ABC):
class FilterableModule(AggregationModule[ConfigScope], ExpirableMixin, ABC):
"""Filterable Module are affected by filters in mod options."""

allowed_mod_options = {"filters", "pipeline_index"}
allowed_mod_options = {"filters", "pipeline_index", "without_postprocessing"}

def __init__(
self,
Expand All @@ -56,3 +57,42 @@ def get_dataset_split(self, name: DatasetSplitName = None) -> Dataset:
"""
ds = super().get_dataset_split(name)
return filter_dataset_split(ds, filters=self.mod_options.filters, config=self.config)

def _get_predictions_from_ds(self) -> List[int]:
"""Get predicted classes according to the module options (with or without postprocessing).

Returns: List of Predictions
"""
ds = self.get_dataset_split()
if self.mod_options.without_postprocessing:
return cast(List[int], [preds[0] for preds in ds[DatasetColumn.model_predictions]])
else:
return cast(List[int], ds[DatasetColumn.postprocessed_prediction])

def _get_confidences_from_ds(self) -> List[List[float]]:
"""Get confidences according to the module options (with or without postprocessing).

Notes: Confidences are sorted according to their values (not the class id).

Returns: List of Confidences
"""
ds = self.get_dataset_split()
confidences = (
ds[DatasetColumn.model_confidences]
if self.mod_options.without_postprocessing
else ds[DatasetColumn.postprocessed_confidences]
)
return cast(List[List[float]], confidences)

def _get_outcomes_from_ds(self) -> List[OutcomeName]:
"""Get outcomes according to the module options (with or without postprocessing).

Returns: List of Outcomes
"""
ds = self.get_dataset_split()
outcomes = (
ds[DatasetColumn.model_outcome]
if self.mod_options.without_postprocessing
else ds[DatasetColumn.postprocessed_outcome]
)
return cast(List[OutcomeName], outcomes)
2 changes: 1 addition & 1 deletion azimuth/modules/model_contracts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Do not reorder
from azimuth.modules.model_contracts.hf_text_classification import HFTextClassificationModule
from azimuth.modules.model_contracts.custom_classification import CustomTextClassificationModule
from azimuth.modules.model_contracts.file_based_text_classification import (
FileBasedTextClassificationModule,
)
from azimuth.modules.model_contracts.hf_text_classification import HFTextClassificationModule
29 changes: 8 additions & 21 deletions azimuth/modules/model_performance/confidence_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List

import numpy as np
import structlog
from datasets import Dataset

from azimuth.config import ModelContractConfig
Expand All @@ -21,39 +20,28 @@

CONFIDENCE_BINS_COUNT = 20

log = structlog.get_logger()


class ConfidenceBinningTask:
"""Common functions to modules related to confidence bins."""

@staticmethod
def get_outcome_mask(ds: Dataset, outcome: OutcomeName) -> List[bool]:
return [utterance_outcome == outcome for utterance_outcome in ds[DatasetColumn.outcome]]

@staticmethod
def get_confidence_interval() -> np.ndarray:
return np.linspace(0, 1, CONFIDENCE_BINS_COUNT + 1)


class ConfidenceHistogramModule(FilterableModule[ModelContractConfig], ConfidenceBinningTask):
class ConfidenceHistogramModule(FilterableModule[ModelContractConfig]):
"""Return a confidence histogram of the predictions."""

def get_outcome_mask(self, outcome: OutcomeName) -> List[bool]:
return [utterance_outcome == outcome for utterance_outcome in self._get_outcomes_from_ds()]

def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]: # type: ignore
"""Compute the confidence histogram with CONFIDENCE_BINS_COUNT bins on the dataset split.

Returns:
List of the confidence bins with their confidence and the outcome count.

"""
bins = self.get_confidence_interval()
bins = np.linspace(0, 1, CONFIDENCE_BINS_COUNT + 1)

ds: Dataset = assert_not_none(self.get_dataset_split())

result = []
if len(ds) > 0:
# Get the bin index for each prediction.
confidences = np.max(ds[DatasetColumn.postprocessed_confidences], axis=1)
confidences = np.max(self._get_confidences_from_ds(), axis=1)
bin_indices = np.floor(confidences * CONFIDENCE_BINS_COUNT)

# Create the records. We drop the last bin as it's the maximum.
Expand All @@ -62,7 +50,7 @@ def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]: # type
outcome_count = defaultdict(int)
for outcome in ALL_OUTCOMES:
outcome_count[outcome] = np.logical_and(
bin_mask, self.get_outcome_mask(ds, outcome)
bin_mask, self.get_outcome_mask(outcome)
).sum()
mean_conf = (
0 if bin_mask.sum() == 0 else np.nan_to_num(confidences[bin_mask].mean())
Expand Down Expand Up @@ -90,7 +78,7 @@ def compute_on_dataset_split(self) -> List[ConfidenceHistogramResponse]: # type
return [ConfidenceHistogramResponse(details_all_bins=result)]


class ConfidenceBinIndexModule(DatasetResultModule[ModelContractConfig], ConfidenceBinningTask):
class ConfidenceBinIndexModule(DatasetResultModule[ModelContractConfig]):
"""Return confidence bin indices for the selected dataset split."""

allowed_mod_options = DatasetResultModule.allowed_mod_options | {"threshold", "pipeline_index"}
Expand All @@ -102,7 +90,6 @@ def compute_on_dataset_split(self) -> List[int]: # type: ignore
List of bin indices for all utterances.

"""
self.get_confidence_interval()
ds = assert_not_none(self.get_dataset_split())

bin_indices: List[int] = (
Expand Down
3 changes: 1 addition & 2 deletions azimuth/modules/model_performance/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from azimuth.config import ModelContractConfig
from azimuth.modules.base_classes import FilterableModule
from azimuth.types import DatasetColumn
from azimuth.types.model_performance import ConfusionMatrixResponse
from azimuth.utils.validation import assert_not_none

Expand All @@ -26,7 +25,7 @@ def compute_on_dataset_split(self) -> List[ConfusionMatrixResponse]: # type: ig
"""
ds: Dataset = assert_not_none(self.get_dataset_split())
predictions, labels = (
ds[DatasetColumn.postprocessed_prediction],
self._get_predictions_from_ds(),
ds[self.config.columns.label],
)
ds_mng = self.get_dataset_split_manager()
Expand Down
54 changes: 24 additions & 30 deletions azimuth/modules/model_performance/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ConfidenceHistogramModule,
)
from azimuth.plots.ece import make_ece_figure
from azimuth.types import DatasetColumn, DatasetFilters, ModuleOptions
from azimuth.types import DatasetColumn, DatasetFilters
from azimuth.types.model_performance import (
MetricsAPIResponse,
MetricsModuleResponse,
Expand Down Expand Up @@ -51,27 +51,6 @@ def first_value(di: Optional[Dict]) -> Optional[float]:
return next(iter(di.values()), None)


def make_probabilities(dataset: Dataset, num_classes: int) -> np.ndarray:
"""Make probabilities from dataset columns.

Args:
dataset: Dataset holding predictions and confidence.
num_classes: Number of classes

Returns:
Array with shape [len(dataset), num_classes] with probabilities.
"""
probs = np.zeros([len(dataset), num_classes])
for idx, (confidences, predictions) in enumerate(
zip(
dataset[DatasetColumn.postprocessed_confidences],
dataset[DatasetColumn.model_predictions],
)
):
probs[idx] = np.array(confidences)[predictions]
return probs


class MetricsModule(FilterableModule[ModelContractConfig]):
"""Computes different metrics on each dataset split."""

Expand All @@ -83,16 +62,14 @@ def compute_on_dataset_split(self) -> List[MetricsModuleResponse]: # type: igno
return [BASE_RESPONSE]

utterance_count = len(indices)
outcome_count = Counter(ds[DatasetColumn.outcome])
outcome_count = Counter(self._get_outcomes_from_ds())
outcome_count.update({outcome: 0 for outcome in ALL_OUTCOMES})

# Compute ECE
conf_hist_mod = ConfidenceHistogramModule(
dataset_split_name=self.dataset_split_name,
config=self.config,
mod_options=ModuleOptions(
filters=self.mod_options.filters, pipeline_index=self.mod_options.pipeline_index
),
mod_options=self.mod_options,
)
bins = conf_hist_mod.compute_on_dataset_split()[0].details_all_bins
ece, acc, expected = compute_ece_from_bins(bins)
Expand All @@ -109,9 +86,7 @@ def compute_on_dataset_split(self) -> List[MetricsModuleResponse]: # type: igno
)
accept_probabilities = "probabilities" in inspect.signature(met._compute).parameters
extra_kwargs = (
dict(probabilities=make_probabilities(ds, dm.get_num_classes(labels_only=True)))
if accept_probabilities
else {}
dict(probabilities=self.make_probabilities()) if accept_probabilities else {}
)
extra_kwargs.update(metric_obj_def.additional_kwargs)
with warnings.catch_warnings():
Expand All @@ -121,7 +96,7 @@ def compute_on_dataset_split(self) -> List[MetricsModuleResponse]: # type: igno
metric_values[metric_name] = assert_not_none(
first_value(
met.compute(
predictions=ds["postprocessed_prediction"],
predictions=self._get_predictions_from_ds(),
references=ds[self.config.columns.label],
**extra_kwargs,
)
Expand Down Expand Up @@ -158,6 +133,25 @@ def module_to_api_response(res: List[MetricsModuleResponse]) -> List[MetricsAPIR
res_with_plot = MetricsAPIResponse(**metrics_res.dict(), ece_plot=fig)
return [res_with_plot]

def make_probabilities(self) -> np.ndarray:
"""Make probabilities from dataset columns.

Returns:
Array with shape [len(dataset), num_classes] with probabilities.
"""
ds = self.get_dataset_split()
num_classes = self.get_dataset_split_manager().get_num_classes(labels_only=True)

probs = np.zeros([len(ds), num_classes])
for idx, (confidences, predictions) in enumerate(
zip(
self._get_confidences_from_ds(),
ds[DatasetColumn.model_predictions],
)
):
probs[idx] = np.array(confidences)[predictions]
return probs


class MetricsPerFilterModule(AggregationModule[AzimuthConfig]):
"""Computes the metrics for each filter."""
Expand Down
12 changes: 6 additions & 6 deletions azimuth/modules/model_performance/outcome_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_outcome_count_per_class(
"""
outcome_count_per_class: Dict[Tuple[str, OutcomeName], int] = defaultdict(int)

for utterance_class, outcome in zip(ds[dataset_column], ds[DatasetColumn.outcome]):
for utterance_class, outcome in zip(ds[dataset_column], self._get_outcomes_from_ds()):
outcome_count_per_class[(dm.get_class_names()[utterance_class], outcome)] += 1

return sorted_by_utterance_count_with_last(
Expand All @@ -73,7 +73,7 @@ def get_outcome_count_per_tag(
all_tags = dm.get_tags(
indices=assert_is_list(ds[DatasetColumn.row_idx]), table_key=self._get_table_key()
)
for utterance_tags, outcome in zip(all_tags, ds[DatasetColumn.outcome]):
for utterance_tags, outcome in zip(all_tags, self._get_outcomes_from_ds()):
no_tag = True
for filter_, tagged in utterance_tags.items():
if tagged and filter_ in filters[:-1]:
Expand All @@ -86,8 +86,7 @@ def get_outcome_count_per_tag(
self.get_outcome_count(outcome_count_per_tag, filters), -1
)

@classmethod
def get_outcome_count_per_outcome(cls, ds: Dataset) -> List[OutcomeCountPerFilterValue]:
def get_outcome_count_per_outcome(self, ds: Dataset) -> List[OutcomeCountPerFilterValue]:
"""Compute outcome count per outcome.

Args:
Expand All @@ -97,7 +96,7 @@ def get_outcome_count_per_outcome(cls, ds: Dataset) -> List[OutcomeCountPerFilte
List of Outcome Count for each outcome.

"""
outcome_count = defaultdict(int, Counter(ds[DatasetColumn.outcome]))
outcome_count = defaultdict(int, Counter(self._get_outcomes_from_ds()))
empty_outcome_count = {outcome: 0 for outcome in OutcomeName}

metrics = [
Expand Down Expand Up @@ -187,10 +186,11 @@ def compute_on_dataset_split(self) -> List[OutcomeCountPerThresholdResponse]: #
),
)
outcomes = outcomes_mod.compute_on_dataset_split()
postprocessed_outcomes = [outcome.postprocessed_outcome for outcome in outcomes]
result.append(
OutcomeCountPerThresholdValue(
threshold=th,
outcome_count=Counter(outcomes),
outcome_count=Counter(postprocessed_outcomes),
)
)
return [OutcomeCountPerThresholdResponse(outcome_count_all_thresholds=result)]
Loading