
Commit 7fbf3c4

Add GaussianSmooth as antialiasing filter in Resize (#4249)
Signed-off-by: Can Zhao <canz@nvidia.com>
Parent: 5536cc3

2 files changed: +69, -6 lines

monai/transforms/spatial/array.py

Lines changed: 42 additions & 0 deletions
@@ -26,6 +26,7 @@
 from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
 from monai.networks.utils import meshgrid_ij, normalize_transform
 from monai.transforms.croppad.array import CenterSpatialCrop, Pad
+from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
 from monai.transforms.utils import (
     create_control_grid,
@@ -622,6 +623,15 @@ class Resize(Transform):
         align_corners: This only has an effect when mode is
             'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
             See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+        anti_aliasing: bool
+            Whether to apply a Gaussian filter to smooth the image prior
+            to downsampling. It is crucial to filter when downsampling
+            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
+        anti_aliasing_sigma: {float, tuple of floats}, optional
+            Standard deviation for Gaussian filtering used when anti-aliasing.
+            By default, this value is chosen as (s - 1) / 2 where s is the
+            downsampling factor, where s > 1. For the up-size case, s < 1, no
+            anti-aliasing is performed prior to rescaling.
     """

     backend = [TransformBackends.TORCH]
@@ -632,17 +642,23 @@ def __init__(
         size_mode: str = "all",
         mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
         align_corners: Optional[bool] = None,
+        anti_aliasing: bool = False,
+        anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
     ) -> None:
         self.size_mode = look_up_option(size_mode, ["all", "longest"])
         self.spatial_size = spatial_size
         self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
         self.align_corners = align_corners
+        self.anti_aliasing = anti_aliasing
+        self.anti_aliasing_sigma = anti_aliasing_sigma

     def __call__(
         self,
         img: NdarrayOrTensor,
         mode: Optional[Union[InterpolateMode, str]] = None,
         align_corners: Optional[bool] = None,
+        anti_aliasing: Optional[bool] = None,
+        anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
     ) -> NdarrayOrTensor:
         """
         Args:
@@ -653,11 +669,23 @@ def __call__(
             align_corners: This only has an effect when mode is
                 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
                 See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+            anti_aliasing: bool, optional
+                Whether to apply a Gaussian filter to smooth the image prior
+                to downsampling. It is crucial to filter when downsampling
+                the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
+            anti_aliasing_sigma: {float, tuple of floats}, optional
+                Standard deviation for Gaussian filtering used when anti-aliasing.
+                By default, this value is chosen as (s - 1) / 2 where s is the
+                downsampling factor, where s > 1. For the up-size case, s < 1, no
+                anti-aliasing is performed prior to rescaling.

         Raises:
             ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.

         """
+        anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing
+        anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma
+
         img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
         if self.size_mode == "all":
             input_ndim = img_.ndim - 1  # spatial ndim
@@ -677,6 +705,20 @@ def __call__(
                 raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
             scale = self.spatial_size / max(img_size)
             spatial_size_ = tuple(int(round(s * scale)) for s in img_size)
+
+        if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])):
+            factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_))
+            if anti_aliasing_sigma is None:
+                # if sigma is not given, use the default sigma in skimage.transform.resize
+                anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist()
+            else:
+                # if sigma is given, use the given value for downsampling axis
+                anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(spatial_size_)))
+                for axis in range(len(spatial_size_)):
+                    anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)
+            anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
+            img_ = anti_aliasing_filter(img_)
+
         resized = torch.nn.functional.interpolate(
             input=img_.unsqueeze(0),
             size=spatial_size_,
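Not part of the commit, but for context: a minimal usage sketch of the new flags, assuming a channel-first NumPy image (the shape convention used in the tests below). When anti_aliasing is enabled and any output axis is smaller than the input, the transform smooths with GaussianSmooth before interpolating; with anti_aliasing_sigma=None the per-axis sigma defaults to max(0, (s - 1) / 2), where s is that axis's downsampling factor.

import numpy as np
from monai.transforms import Resize

# channel-first image: 1 channel, 64 x 64 spatial size (illustrative data, not from the commit)
img = np.random.rand(1, 64, 64).astype(np.float32)

# downsampling with the new flag: a Gaussian smoothing pass runs before the interpolation
resize = Resize(spatial_size=(16, 16), mode="area", anti_aliasing=True)
out = resize(img)  # spatial shape (16, 16); default sigma per axis is (64 / 16 - 1) / 2 = 1.5

# both flags can also be overridden per call, e.g. with an explicit sigma
out = resize(img, anti_aliasing=True, anti_aliasing_sigma=2.0)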

tests/test_resize.py

Lines changed: 27 additions & 6 deletions
@@ -13,6 +13,7 @@

 import numpy as np
 import skimage.transform
+import torch
 from parameterized import parameterized

 from monai.transforms import Resize
@@ -24,6 +25,10 @@

 TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)]

+TEST_CASE_3 = [{"spatial_size": 15, "anti_aliasing": True}, (6, 10, 15)]
+
+TEST_CASE_4 = [{"spatial_size": 6, "anti_aliasing": True, "anti_aliasing_sigma": 2.0}, (2, 4, 6)]
+

 class TestResize(NumpyImageTestCase2D):
     def test_invalid_inputs(self):
@@ -36,28 +41,44 @@ def test_invalid_inputs(self):
             resize(self.imt[0])

     @parameterized.expand(
-        [((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")]
+        [
+            ((32, -1), "area", True),
+            ((32, 32), "area", False),
+            ((32, 32, 32), "trilinear", True),
+            ((256, 256), "bilinear", False),
+        ]
     )
-    def test_correct_results(self, spatial_size, mode):
-        resize = Resize(spatial_size, mode=mode)
+    def test_correct_results(self, spatial_size, mode, anti_aliasing):
+        resize = Resize(spatial_size, mode=mode, anti_aliasing=anti_aliasing)
         _order = 0
         if mode.endswith("linear"):
             _order = 1
         if spatial_size == (32, -1):
             spatial_size = (32, 64)
         expected = [
             skimage.transform.resize(
-                channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False
+                channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing
             )
             for channel in self.imt[0]
         ]

         expected = np.stack(expected).astype(np.float32)
         for p in TEST_NDARRAYS:
             out = resize(p(self.imt[0]))
-            assert_allclose(out, expected, type_test=False, atol=0.9)
+            if not anti_aliasing:
+                assert_allclose(out, expected, type_test=False, atol=0.9)
+            else:
+                # skimage uses reflect padding for anti-aliasing filter.
+                # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead.
+                # Thus their results near the image boundary will be different.
+                if isinstance(out, torch.Tensor):
+                    out = out.cpu().detach().numpy()
+                good = np.sum(np.isclose(expected, out, atol=0.9))
+                self.assertLessEqual(
+                    np.abs(good - expected.size) / float(expected.size), 0.2, "at most 20 percent mismatch "
+                )

-    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
+    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
     def test_longest_shape(self, input_param, expected_shape):
         input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
         input_param["size_mode"] = "longest"
