Skip to content

Commit 026f840

Browse files
committed
Merge branch 'dev' into 4108-pixelshuffle-scriptable
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
2 parents 4ab097c + 9c0a538 commit 026f840

File tree

16 files changed

+39
-140
lines changed

16 files changed

+39
-140
lines changed

.github/workflows/cron.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: [self-hosted, linux, x64, common]
1616
strategy:
1717
matrix:
18-
pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest]
18+
pytorch-version: [1.7.1, 1.8.1, 1.9.1, 1.10.2, latest]
1919
steps:
2020
- uses: actions/checkout@v2
2121
- name: Install the dependencies
@@ -24,15 +24,15 @@ jobs:
2424
python -m pip install --upgrade pip wheel
2525
python -m pip uninstall -y torch torchvision
2626
if [ ${{ matrix.pytorch-version }} == "latest" ]; then
27-
python -m pip install torch torchvision
28-
elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then
29-
python -m pip install torch==1.6.0 torchvision==0.7.0
27+
python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
3028
elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then
31-
python -m pip install torch==1.7.1 torchvision==0.8.2
29+
python -m pip install torch==1.7.1 torchvision==0.8.2 --extra-index-url https://download.pytorch.org/whl/cu113
3230
elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then
33-
python -m pip install torch==1.8.1 torchvision==0.9.1
31+
python -m pip install torch==1.8.1 torchvision==0.9.1 --extra-index-url https://download.pytorch.org/whl/cu113
3432
elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
35-
python -m pip install torch==1.9.1 torchvision==0.10.1
33+
python -m pip install torch==1.9.1 torchvision==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113
34+
elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
35+
python -m pip install torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113
3636
fi
3737
python -m pip install -r requirements-dev.txt
3838
python -m pip list

.github/workflows/pythonapp-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
pytorch: "-h"
5151
base: "nvcr.io/nvidia/pytorch:22.03-py3"
5252
- environment: PT110+CUDA102
53-
pytorch: "torch==1.10.1 torchvision==0.11.2"
53+
pytorch: "torch==1.10.2 torchvision==0.11.3"
5454
base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
5555
- environment: PT111+CUDA102
5656
pytorch: "torch==1.11.0 torchvision==0.12.0"

.github/workflows/pythonapp-min.yml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
strategy:
120120
fail-fast: false
121121
matrix:
122-
pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, 1.10.1, latest]
122+
pytorch-version: [1.7.1, 1.8.1, 1.9.1, 1.10.2, latest]
123123
timeout-minutes: 40
124124
steps:
125125
- uses: actions/checkout@v2
@@ -148,16 +148,14 @@ jobs:
148148
# min. requirements
149149
if [ ${{ matrix.pytorch-version }} == "latest" ]; then
150150
python -m pip install torch
151-
elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then
152-
python -m pip install torch==1.6.0
153151
elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then
154152
python -m pip install torch==1.7.1
155153
elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then
156154
python -m pip install torch==1.8.1
157155
elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
158156
python -m pip install torch==1.9.1
159-
elif [ ${{ matrix.pytorch-version }} == "1.10.1" ]; then
160-
python -m pip install torch==1.10.1
157+
elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
158+
python -m pip install torch==1.10.2
161159
fi
162160
python -m pip install -r requirements-min.txt
163161
python -m pip list

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ jobs:
137137
# install the latest pytorch for testing
138138
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
139139
# fresh torch installation according to pyproject.toml
140-
python -m pip install torch>=1.6 torchvision
140+
python -m pip install torch>=1.7 torchvision
141141
- name: Check packages
142142
run: |
143143
pip uninstall monai

monai/data/torchscript_utils.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from monai.config import get_config_values
2020
from monai.utils import JITMetadataKeys
21-
from monai.utils.module import pytorch_after
2221

2322
METADATA_FILENAME = "metadata.json"
2423

@@ -80,19 +79,10 @@ def save_net_with_metadata(
8079

8180
json_data = json.dumps(metadict)
8281

83-
# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object
84-
if pytorch_after(1, 7):
85-
extra_files = {METADATA_FILENAME: json_data.encode()}
82+
extra_files = {METADATA_FILENAME: json_data.encode()}
8683

87-
if more_extra_files is not None:
88-
extra_files.update(more_extra_files)
89-
else:
90-
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined]
91-
extra_files[METADATA_FILENAME] = json_data.encode()
92-
93-
if more_extra_files is not None:
94-
for k, v in more_extra_files.items():
95-
extra_files[k] = v
84+
if more_extra_files is not None:
85+
extra_files.update(more_extra_files)
9686

9787
if isinstance(filename_prefix_or_stream, str):
9888
filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream)
@@ -123,16 +113,8 @@ def load_net_with_metadata(
123113
Returns:
124114
Triple containing loaded object, metadata dict, and extra files dict containing other file data if present
125115
"""
126-
# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object
127-
if pytorch_after(1, 7):
128-
extra_files = {f: "" for f in more_extra_files}
129-
extra_files[METADATA_FILENAME] = ""
130-
else:
131-
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined]
132-
extra_files[METADATA_FILENAME] = ""
133-
134-
for f in more_extra_files:
135-
extra_files[f] = ""
116+
extra_files = {f: "" for f in more_extra_files}
117+
extra_files[METADATA_FILENAME] = ""
136118

137119
jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files)
138120

monai/engines/trainer.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.engines.workflow import Workflow
2727
from monai.inferers import Inferer, SimpleInferer
2828
from monai.transforms import Transform
29-
from monai.utils import min_version, optional_import, pytorch_after
29+
from monai.utils import min_version, optional_import
3030
from monai.utils.enums import CommonKeys as Keys
3131

3232
if TYPE_CHECKING:
@@ -193,11 +193,7 @@ def _compute_pred_loss():
193193
engine.fire_event(IterationEvents.LOSS_COMPLETED)
194194

195195
self.network.train()
196-
# `set_to_none` only work from PyTorch 1.7.0
197-
if not pytorch_after(1, 7):
198-
self.optimizer.zero_grad()
199-
else:
200-
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)
196+
self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)
201197

202198
if self.amp and self.scaler is not None:
203199
with torch.cuda.amp.autocast():
@@ -366,11 +362,7 @@ def _iteration(
366362
# Train Discriminator
367363
d_total_loss = torch.zeros(1)
368364
for _ in range(self.d_train_steps):
369-
# `set_to_none` only work from PyTorch 1.7.0
370-
if not pytorch_after(1, 7):
371-
self.d_optimizer.zero_grad()
372-
else:
373-
self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
365+
self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
374366
dloss = self.d_loss_function(g_output, d_input)
375367
dloss.backward()
376368
self.d_optimizer.step()
@@ -385,10 +377,7 @@ def _iteration(
385377
non_blocking=engine.non_blocking, # type: ignore
386378
)
387379
g_output = self.g_inferer(g_input, self.g_network)
388-
if not pytorch_after(1, 7):
389-
self.g_optimizer.zero_grad()
390-
else:
391-
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
380+
self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
392381
g_loss = self.g_loss_function(g_output)
393382
g_loss.backward()
394383
self.g_optimizer.step()

monai/networks/layers/simplelayers.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,11 @@
2020

2121
from monai.networks.layers.convutils import gaussian_1d
2222
from monai.networks.layers.factories import Conv
23-
from monai.utils import (
24-
ChannelMatching,
25-
InvalidPyTorchVersionError,
26-
SkipMode,
27-
look_up_option,
28-
optional_import,
29-
pytorch_after,
30-
)
23+
from monai.utils import ChannelMatching, SkipMode, look_up_option, optional_import, pytorch_after
3124
from monai.utils.misc import issequenceiterable
3225

3326
_C, _ = optional_import("monai._C")
34-
if pytorch_after(1, 7):
35-
fft, _ = optional_import("torch.fft")
27+
fft, _ = optional_import("torch.fft")
3628

3729
__all__ = [
3830
"ChannelPad",
@@ -377,7 +369,6 @@ def _make_coeffs(window_length, order):
377369
class HilbertTransform(nn.Module):
378370
"""
379371
Determine the analytical signal of a Tensor along a particular axis.
380-
Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10).
381372
382373
Args:
383374
axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
@@ -386,9 +377,6 @@ class HilbertTransform(nn.Module):
386377

387378
def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:
388379

389-
if not pytorch_after(1, 7):
390-
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)
391-
392380
super().__init__()
393381
self.axis = axis
394382
self.n = n

monai/networks/utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,6 @@ def convert_to_torchscript(
504504
filename_or_obj: if not None, specify a file-like object (has to implement write and flush)
505505
or a string containing a file path name to save the TorchScript model.
506506
extra_files: map from filename to contents which will be stored as part of the save model file.
507-
works for PyTorch 1.7 or later.
508507
for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.
509508
verify: whether to verify the input and output of TorchScript model.
510509
if `filename_or_obj` is not None, load the saved TorchScript model and verify.
@@ -521,10 +520,7 @@ def convert_to_torchscript(
521520
with torch.no_grad():
522521
script_module = torch.jit.script(model, **kwargs)
523522
if filename_or_obj is not None:
524-
if not pytorch_after(1, 7):
525-
torch.jit.save(m=script_module, f=filename_or_obj)
526-
else:
527-
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)
523+
torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)
528524

529525
if verify:
530526
if device is None:

monai/transforms/intensity/array.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,12 @@
3030
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
3131
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
3232
from monai.utils import (
33-
InvalidPyTorchVersionError,
3433
convert_data_type,
3534
convert_to_dst_type,
3635
ensure_tuple,
3736
ensure_tuple_rep,
3837
ensure_tuple_size,
3938
fall_back_tuple,
40-
pytorch_after,
4139
)
4240
from monai.utils.deprecate_utils import deprecated_arg
4341
from monai.utils.enums import TransformBackends
@@ -1085,7 +1083,6 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
10851083
class DetectEnvelope(Transform):
10861084
"""
10871085
Find the envelope of the input data along the requested axis using a Hilbert transform.
1088-
Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10).
10891086
10901087
Args:
10911088
axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension.
@@ -1098,9 +1095,6 @@ class DetectEnvelope(Transform):
10981095

10991096
def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None:
11001097

1101-
if not pytorch_after(1, 7):
1102-
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)
1103-
11041098
if axis < 0:
11051099
raise ValueError("axis must be zero or positive.")
11061100

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616

1717
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
18-
from monai.utils.misc import ensure_tuple, is_module_ver_at_least
18+
from monai.utils.misc import is_module_ver_at_least
1919
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
2020

2121
__all__ = [
@@ -54,31 +54,12 @@ def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_n
5454

5555

5656
def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor:
57-
"""`moveaxis` for pytorch and numpy, using `permute` for pytorch version < 1.7"""
57+
"""`moveaxis` for pytorch and numpy"""
5858
if isinstance(x, torch.Tensor):
59-
if hasattr(torch, "movedim"): # `movedim` is new in torch 1.7.0
60-
# torch.moveaxis is a recent alias since torch 1.8.0
61-
return torch.movedim(x, src, dst) # type: ignore
62-
return _moveaxis_with_permute(x, src, dst)
59+
return torch.movedim(x, src, dst) # type: ignore
6360
return np.moveaxis(x, src, dst)
6461

6562

66-
def _moveaxis_with_permute(
67-
x: torch.Tensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]
68-
) -> torch.Tensor:
69-
# get original indices
70-
indices = list(range(x.ndim))
71-
len_indices = len(indices)
72-
for s, d in zip(ensure_tuple(src), ensure_tuple(dst)):
73-
# make src and dst positive
74-
# remove desired index and insert it in new position
75-
pos_s = len_indices + s if s < 0 else s
76-
pos_d = len_indices + d if d < 0 else d
77-
indices.pop(pos_s)
78-
indices.insert(pos_d, pos_s)
79-
return x.permute(indices)
80-
81-
8263
def in1d(x, y):
8364
"""`np.in1d` with equivalent implementation for torch."""
8465
if isinstance(x, np.ndarray):
@@ -101,18 +82,15 @@ def percentile(
10182
) -> Union[NdarrayOrTensor, float, int]:
10283
"""`np.percentile` with equivalent implementation for torch.
10384
104-
Pytorch uses `quantile`, but this functionality is only available from v1.7.
105-
For earlier methods, we calculate it ourselves. This doesn't do interpolation,
106-
so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
107-
For more details, please refer to:
85+
Pytorch uses `quantile`. For more details please refer to:
10886
https://pytorch.org/docs/stable/generated/torch.quantile.html.
10987
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
11088
11189
Args:
11290
x: input data
11391
q: percentile to compute (should in range 0 <= q <= 100)
11492
dim: the dim along which the percentiles are computed. default is to compute the percentile
115-
along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
93+
along a flattened version of the array.
11694
keepdim: whether the output data has dim retained or not.
11795
kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
11896
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
@@ -130,18 +108,7 @@ def percentile(
130108
result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs)
131109
else:
132110
q = torch.tensor(q, device=x.device)
133-
if hasattr(torch, "quantile"): # `quantile` is new in torch 1.7.0
134-
result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim)
135-
else:
136-
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
137-
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
138-
k = 1 + (0.01 * q * (x.numel() - 1)).round().int()
139-
if k.numel() > 1:
140-
r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k]
141-
result = torch.tensor(r, device=x.device)
142-
else:
143-
result = x.view(-1).kthvalue(int(k)).values.item()
144-
111+
result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim)
145112
return result
146113

147114

@@ -277,8 +244,6 @@ def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrT
277244
def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
278245
"""`np.maximum` with equivalent implementation for torch.
279246
280-
`torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`.
281-
282247
Args:
283248
a: first array/tensor
284249
b: second array/tensor
@@ -287,10 +252,7 @@ def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
287252
Element-wise maximum between two arrays/tensors.
288253
"""
289254
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
290-
# is torch and has torch.maximum (pt>1.6)
291-
if hasattr(torch, "maximum"): # `maximum` is new in torch 1.7.0
292-
return torch.maximum(a, b)
293-
return torch.stack((a, b)).max(dim=0)[0]
255+
return torch.maximum(a, b)
294256
return np.maximum(a, b)
295257

296258

0 commit comments

Comments
 (0)