Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit cd7d30a

Browse files
Added the MMD Metric and tests (#152)
* Added the MMD Metric and tests * Fixed MMD yy calculations and docs. Signed-off-by: Petru-Daniel Tudosiu <petru.daniel@tudosiu.com>
1 parent 5da0ea9 commit cd7d30a

File tree

3 files changed

+139
-0
lines changed

3 files changed

+139
-0
lines changed

generative/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from .mmd import MMD
1213
from .ms_ssim import MSSSIM

generative/metrics/mmd.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from typing import Callable, Optional, Union
12+
13+
import torch
14+
from monai.metrics.regression import RegressionMetric
15+
from monai.utils import MetricReduction
16+
17+
18+
class MMD(RegressionMetric):
19+
"""
20+
Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
21+
distributions. It is a non-negative metric where a smaller value indicates a closer match between the two
22+
distributions.
23+
24+
Gretton, A., et al,, 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.
25+
26+
Args:
27+
y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace
28+
filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a
29+
feature extractor or an Identity function.
30+
y_pred_transform: Callable to transform the y_pred tensor before computing the metric.
31+
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, available
32+
reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``,
33+
`"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. This parameter is ignored due to
34+
the mathematical formulation of MMD.
35+
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here
36+
`not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
37+
This parameter is ignored due to the mathematical formulation of MMD.
38+
39+
"""
40+
41+
def __init__(
42+
self,
43+
y_transform: Optional[Callable] = None,
44+
y_pred_transform: Optional[Callable] = None,
45+
reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
46+
get_not_nans: bool = False,
47+
) -> None:
48+
super().__init__(reduction=reduction, get_not_nans=get_not_nans)
49+
50+
self.y_transform = y_transform
51+
self.y_pred_transform = y_pred_transform
52+
53+
def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
54+
"""
55+
Args:
56+
y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
57+
y_pred: second sample (e.g., the reconstructed image). It has similar shape as y.
58+
"""
59+
60+
# Beta and Gamma are not calculated since torch.mean is used at return
61+
beta = 1.0
62+
gamma = 2.0
63+
64+
if self.y_transform is not None:
65+
y = self.y_transform(y)
66+
67+
if self.y_pred_transform is not None:
68+
y_pred = self.y_pred_transform(y_pred)
69+
70+
if y_pred.shape != y.shape:
71+
raise ValueError(
72+
f"y_pred and y shapes dont match after being processed by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
73+
)
74+
75+
for d in range(len(y.shape) - 1, 1, -1):
76+
y = y.squeeze(dim=d)
77+
y_pred = y_pred.squeeze(dim=d)
78+
79+
y = y.view(y.shape[0], -1)
80+
y_pred = y_pred.view(y_pred.shape[0], -1)
81+
82+
y_y = torch.mm(y, y.t())
83+
y_pred_y_pred = torch.mm(y_pred, y_pred.t())
84+
y_pred_y = torch.mm(y_pred, y.t())
85+
86+
y_y = y_y / y.shape[1]
87+
y_pred_y_pred = y_pred_y_pred / y.shape[1]
88+
y_pred_y = y_pred_y / y.shape[1]
89+
90+
# Ref. 1 Eq. 3 (found under Lemma 6)
91+
return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y)

tests/test_compute_mmd_metric.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
import unittest
14+
15+
import numpy as np
16+
import torch
17+
from parameterized import parameterized
18+
19+
from generative.metrics import MMD
20+
21+
TEST_CASES = [
22+
[
23+
{"y_transform": None, "y_pred_transform": None},
24+
{"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])},
25+
0.0,
26+
],
27+
[
28+
{"y_transform": None, "y_pred_transform": None},
29+
{"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])},
30+
0.0,
31+
],
32+
]
33+
34+
35+
class TestMMDMetric(unittest.TestCase):
36+
@parameterized.expand(TEST_CASES)
37+
def test_results(self, input_param, input_data, expected_val):
38+
results = MMD(**input_param)._compute_metric(**input_data)
39+
np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)
40+
41+
def test_if_inputs_different_shapes(self):
42+
with self.assertRaises(ValueError):
43+
MMD()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

0 commit comments

Comments
 (0)