Skip to content

Commit fb017ad

Browse files
MetricsLambda and ClassificationReport ability to use metrics_result_mode. (#3531)
Fixes #3513 The main reason the issue #3513 was created, particularly for me, was the ClassificationReport behavior. So I'm creating this PR to add this functionality, following up on the #3514 PR. --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent ff116eb commit fb017ad

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

ignite/metrics/classification_report.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Callable, Collection, Dict, List, Optional, Union
2+
from typing import Callable, Collection, Dict, List, Literal, Optional, Union
33

44
import torch
55

@@ -18,6 +18,7 @@ def ClassificationReport(
1818
device: Union[str, torch.device] = torch.device("cpu"),
1919
is_multilabel: bool = False,
2020
labels: Optional[List[str]] = None,
21+
metrics_result_mode: Literal["flatten", "named", "both"] = "both",
2122
) -> MetricsLambda:
2223
r"""Build a text report showing the main classification metrics. The report resembles in functionality to
2324
`scikit-learn classification_report
@@ -34,6 +35,11 @@ def ClassificationReport(
3435
is_multilabel: If True, the tensors are assumed to be multilabel.
3536
device: optional device specification for internal storage.
3637
labels: Optional list of label indices to include in the report
38+
metrics_result_mode: specifies how to put the computed metrics results into
39+
``engine.state.metrics`` dictionary. Valid values are: "flatten", "named", "both".
40+
- "flatten": if the computed result is a mapping, its keys/values are put directly into the engine state metrics dictionary
41+
- "named": if the computed result is a mapping, the whole mapping is put into the engine state metrics dictionary under the metric name
42+
- "both": combination of "flatten" and "named".
3743
3844
Examples:
3945
@@ -107,6 +113,8 @@ def ClassificationReport(
107113
{'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0}
108114
{'precision': 0.2333..., 'recall': 0.6666..., 'f1-score': 0.3333...}
109115
116+
.. versionchanged:: 0.5.4
117+
added ``metrics_result_mode`` argument.
110118
"""
111119

112120
# setup all the underlying metrics
@@ -144,4 +152,13 @@ def _wrapper(
144152
def _get_label_for_class(idx: int) -> str:
145153
return labels[idx] if labels else str(idx)
146154

147-
return MetricsLambda(_wrapper, recall, precision, fbeta, averaged_recall, averaged_precision, averaged_fbeta)
155+
return MetricsLambda(
156+
_wrapper,
157+
recall,
158+
precision,
159+
fbeta,
160+
averaged_recall,
161+
averaged_precision,
162+
averaged_fbeta,
163+
metrics_result_mode=metrics_result_mode,
164+
)

ignite/metrics/metrics_lambda.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import itertools
2-
from typing import Any, Callable, Optional, Union
2+
from typing import Any, Callable, Optional, Union, Literal
33

44
import torch
55

@@ -24,6 +24,14 @@ class MetricsLambda(Metric):
2424
f: the function that defines the computation
2525
args: Sequence of other metrics or something
2626
else that will be fed to ``f`` as arguments.
27+
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
28+
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
29+
Alternatively, ``output_transform`` can be used to handle this.
30+
metrics_result_mode: specifies how to put the computed metrics results into
31+
``engine.state.metrics`` dictionary. Valid values are: "flatten", "named", "both".
32+
- "flatten": if the computed result is a mapping, its keys/values are put directly into the engine state metrics dictionary
33+
- "named": if the computed result is a mapping, the whole mapping is put into the engine state metrics dictionary under the metric name
34+
- "both": combination of "flatten" and "named".
2735
kwargs: Sequence of other metrics or something
2836
else that will be fed to ``f`` as keyword arguments.
2937
@@ -88,17 +96,27 @@ def Fbeta(r, p, beta):
8896
assert not aP.is_attached(engine)
8997
# fully attached
9098
assert not precision.is_attached(engine)
99+
100+
.. versionchanged:: 0.5.4
101+
added ``skip_unrolling`` and ``metrics_result_mode`` arguments.
91102
"""
92103

93104
_state_dict_all_req_keys = ("_updated", "args", "kwargs")
94105

95-
def __init__(self, f: Callable, *args: Any, **kwargs: Any) -> None:
106+
def __init__(
107+
self,
108+
f: Callable,
109+
*args: Any,
110+
skip_unrolling: bool = False,
111+
metrics_result_mode: Literal["flatten", "named", "both"] = "both",
112+
**kwargs: Any,
113+
) -> None:
96114
self.function = f
97115
self.args = list(args) # we need args to be a list instead of a tuple for state_dict/load_state_dict feature
98116
self.kwargs = kwargs
99117
self.engine: Optional[Engine] = None
100118
self._updated = False
101-
super().__init__(device="cpu")
119+
super().__init__(device="cpu", metrics_result_mode=metrics_result_mode, skip_unrolling=skip_unrolling)
102120

103121
@reinit__is_reduced
104122
def reset(self) -> None:

tests/ignite/metrics/test_classification_report.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import ignite.distributed as idist
99
from ignite.engine import Engine
10+
from ignite.metrics import MetricsLambda
1011
from ignite.metrics.classification_report import ClassificationReport
1112

1213

@@ -157,6 +158,23 @@ def _test_integration_multiclass(device, output_dict):
157158
_test_multiclass(metric_device, n_classes, output_dict, labels=labels[:n_classes], distributed=True)
158159

159160

161+
@pytest.mark.parametrize(
162+
"metrics_result_mode",
163+
[
164+
"flatten",
165+
"named",
166+
"both",
167+
],
168+
)
169+
def test_metrics_result_mode(metrics_result_mode):
170+
metric = ClassificationReport(output_dict=True, metrics_result_mode=metrics_result_mode)
171+
172+
assert isinstance(metric, MetricsLambda), "ClassificationReport should be an instance of MetricsLambda"
173+
assert (
174+
metric._metrics_result_mode == metrics_result_mode
175+
), f"Expected metrics_result_mode to be {metrics_result_mode}"
176+
177+
160178
def _test_integration_multilabel(device, output_dict):
161179
rank = idist.get_rank()
162180

@@ -197,7 +215,6 @@ def test_compute_multilabel(n_times, available_device):
197215
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
198216
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
199217
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
200-
201218
pytest.skip("Temporarily skip failing test. See https://github.com/pytorch/ignite/pull/3301")
202219
# When run with 2 devices:
203220
# tests/ignite/metrics/test_classification_report.py::test_distrib_nccl_gpu Fatal Python error: Aborted

tests/ignite/metrics/test_metrics_lambda.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,51 @@ def test_load_state_dict():
299299
assert e2 == e
300300

301301

302+
@pytest.mark.parametrize(
303+
"metrics_result_mode",
304+
[
305+
"flatten",
306+
"named",
307+
"both",
308+
],
309+
)
310+
def test_metrics_lambda_result_mode_behavior(metrics_result_mode):
311+
# dummy for now
312+
def dummy_compute_fn(*args, **kwargs):
313+
return {
314+
"precision": 0.5,
315+
"recall": 0.5,
316+
"f1-score": 0.5,
317+
}
318+
319+
class DummyMetric(Metric):
320+
def __init__(self, output_transform=lambda x: x):
321+
super().__init__(output_transform=output_transform, metrics_result_mode=metrics_result_mode)
322+
323+
def reset(self): ...
324+
325+
def update(self, output): ...
326+
327+
def compute(self):
328+
return dummy_compute_fn()
329+
330+
metric_a = MetricsLambda(dummy_compute_fn, metrics_result_mode=metrics_result_mode)
331+
metric_b = DummyMetric()
332+
333+
engine_a = Engine(lambda e, b: b)
334+
metric_a.attach(engine_a, "dummy_metric")
335+
336+
engine_b = Engine(lambda e, b: b)
337+
metric_b.attach(engine_b, "dummy_metric")
338+
339+
state_a = engine_a.run([0], max_epochs=1)
340+
state_b = engine_b.run([0], max_epochs=1)
341+
342+
assert state_a.metrics.keys() == state_b.metrics.keys()
343+
344+
assert state_a.metrics == state_b.metrics
345+
346+
302347
def test_state_metrics():
303348
y_pred = torch.randint(0, 2, size=(15, 10, 4)).float()
304349
y = torch.randint(0, 2, size=(15, 10, 4)).long()

0 commit comments

Comments
 (0)