Skip to content

Commit ee0a368

Browse files
Nic-Mawyli
authored andcommitted
2453 Enhance ToTensor transform to support data with no dim (#2454)
* [DLMED] enhance ToTensor Signed-off-by: Nic Ma <nma@nvidia.com> * [DLMED] update according to comments Signed-off-by: Nic Ma <nma@nvidia.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent c94a45e commit ee0a368

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

monai/transforms/utility/array.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from monai.config import DtypeLike, NdarrayTensor
2626
from monai.transforms.transform import Randomizable, Transform
2727
from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
28-
from monai.utils import ensure_tuple, min_version, optional_import
28+
from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import
2929

3030
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
3131
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -320,7 +320,12 @@ def __call__(self, img) -> torch.Tensor:
320320
"""
321321
if isinstance(img, torch.Tensor):
322322
return img.contiguous()
323-
return torch.as_tensor(np.ascontiguousarray(img))
323+
if issequenceiterable(img):
324+
# numpy array with 0 dims is also sequence iterable
325+
if not (isinstance(img, np.ndarray) and img.ndim == 0):
326+
# `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
327+
img = np.ascontiguousarray(img)
328+
return torch.as_tensor(img)
324329

325330

326331
class ToNumpy(Transform):

tests/test_to_tensor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
17+
from monai.transforms import ToTensor
18+
19+
20+
class TestToTensor(unittest.TestCase):
21+
def test_array_input(self):
22+
for test_data in ([[1, 2], [3, 4]], np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])):
23+
result = ToTensor()(test_data)
24+
torch.testing.assert_allclose(result, test_data)
25+
self.assertTupleEqual(result.shape, (2, 2))
26+
27+
def test_single_input(self):
28+
for test_data in (5, np.asarray(5), torch.tensor(5)):
29+
result = ToTensor()(test_data)
30+
torch.testing.assert_allclose(result, test_data)
31+
self.assertEqual(result.ndim, 0)
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()

0 commit comments

Comments
 (0)