1515import torch
1616
1717from monai .config .type_definitions import NdarrayOrTensor , NdarrayTensor
18- from monai .utils .misc import ensure_tuple , is_module_ver_at_least
18+ from monai .utils .misc import is_module_ver_at_least
1919from monai .utils .type_conversion import convert_data_type , convert_to_dst_type
2020
2121__all__ = [
@@ -54,31 +54,12 @@ def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_n
5454
5555
5656def moveaxis (x : NdarrayOrTensor , src : Union [int , Sequence [int ]], dst : Union [int , Sequence [int ]]) -> NdarrayOrTensor :
57- """`moveaxis` for pytorch and numpy, using `permute` for pytorch version < 1.7 """
57+ """`moveaxis` for pytorch and numpy"""
5858 if isinstance (x , torch .Tensor ):
59- if hasattr (torch , "movedim" ): # `movedim` is new in torch 1.7.0
60- # torch.moveaxis is a recent alias since torch 1.8.0
61- return torch .movedim (x , src , dst ) # type: ignore
62- return _moveaxis_with_permute (x , src , dst )
59+ return torch .movedim (x , src , dst ) # type: ignore
6360 return np .moveaxis (x , src , dst )
6461
6562
66- def _moveaxis_with_permute (
67- x : torch .Tensor , src : Union [int , Sequence [int ]], dst : Union [int , Sequence [int ]]
68- ) -> torch .Tensor :
69- # get original indices
70- indices = list (range (x .ndim ))
71- len_indices = len (indices )
72- for s , d in zip (ensure_tuple (src ), ensure_tuple (dst )):
73- # make src and dst positive
74- # remove desired index and insert it in new position
75- pos_s = len_indices + s if s < 0 else s
76- pos_d = len_indices + d if d < 0 else d
77- indices .pop (pos_s )
78- indices .insert (pos_d , pos_s )
79- return x .permute (indices )
80-
81-
8263def in1d (x , y ):
8364 """`np.in1d` with equivalent implementation for torch."""
8465 if isinstance (x , np .ndarray ):
@@ -101,18 +82,15 @@ def percentile(
10182) -> Union [NdarrayOrTensor , float , int ]:
10283 """`np.percentile` with equivalent implementation for torch.
10384
104- Pytorch uses `quantile`, but this functionality is only available from v1.7.
105- For earlier methods, we calculate it ourselves. This doesn't do interpolation,
106- so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
107- For more details, please refer to:
85+ Pytorch uses `quantile`. For more details please refer to:
10886 https://pytorch.org/docs/stable/generated/torch.quantile.html.
10987 https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
11088
11189 Args:
11290 x: input data
11391 q: percentile to compute (should in range 0 <= q <= 100)
11492 dim: the dim along which the percentiles are computed. default is to compute the percentile
115- along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
93+ along a flattened version of the array.
11694 keepdim: whether the output data has dim retained or not.
11795 kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
11896 https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
@@ -130,18 +108,7 @@ def percentile(
130108 result = np .percentile (x , q , axis = dim , keepdims = keepdim , ** kwargs )
131109 else :
132110 q = torch .tensor (q , device = x .device )
133- if hasattr (torch , "quantile" ): # `quantile` is new in torch 1.7.0
134- result = torch .quantile (x , q / 100.0 , dim = dim , keepdim = keepdim )
135- else :
136- # Note that ``kthvalue()`` works one-based, i.e., the first sorted value
137- # corresponds to k=1, not k=0. Thus, we need the `1 +`.
138- k = 1 + (0.01 * q * (x .numel () - 1 )).round ().int ()
139- if k .numel () > 1 :
140- r = [x .view (- 1 ).kthvalue (int (_k )).values .item () for _k in k ]
141- result = torch .tensor (r , device = x .device )
142- else :
143- result = x .view (- 1 ).kthvalue (int (k )).values .item ()
144-
111+ result = torch .quantile (x , q / 100.0 , dim = dim , keepdim = keepdim )
145112 return result
146113
147114
@@ -277,8 +244,6 @@ def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrT
277244def maximum (a : NdarrayOrTensor , b : NdarrayOrTensor ) -> NdarrayOrTensor :
278245 """`np.maximum` with equivalent implementation for torch.
279246
280- `torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`.
281-
282247 Args:
283248 a: first array/tensor
284249 b: second array/tensor
@@ -287,10 +252,7 @@ def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
287252 Element-wise maximum between two arrays/tensors.
288253 """
289254 if isinstance (a , torch .Tensor ) and isinstance (b , torch .Tensor ):
290- # is torch and has torch.maximum (pt>1.6)
291- if hasattr (torch , "maximum" ): # `maximum` is new in torch 1.7.0
292- return torch .maximum (a , b )
293- return torch .stack ((a , b )).max (dim = 0 )[0 ]
255+ return torch .maximum (a , b )
294256 return np .maximum (a , b )
295257
296258
0 commit comments