Skip to content

Commit 443fc0c

Browse files
authored
ToMetaTensor and FromMetaTensor transforms (#4115)
to and from meta
1 parent 8544d9e commit 443fc0c

File tree

8 files changed

+334
-7
lines changed

8 files changed

+334
-7
lines changed

docs/source/transforms.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,21 @@ Utility (Dict)
18421842
:members:
18431843
:special-members: __call__
18441844

1845+
MetaTensor
1846+
^^^^^^^^^^
1847+
1848+
`ToMetaTensord`
1849+
"""""""""""""""
1850+
.. autoclass:: ToMetaTensord
1851+
:members:
1852+
:special-members: __call__
1853+
1854+
`FromMetaTensord`
1855+
"""""""""""""""""
1856+
.. autoclass:: FromMetaTensord
1857+
:members:
1858+
:special-members: __call__
1859+
18451860
Transform Adaptors
18461861
------------------
18471862
.. automodule:: monai.transforms.adaptors

monai/data/meta_tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import warnings
15+
from copy import deepcopy
1516
from typing import Callable
1617

1718
import torch
@@ -88,6 +89,10 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No
8889
self.affine = x.affine
8990
else:
9091
self.affine = self.get_default_affine()
92+
93+
# if we are creating a new MetaTensor, then deep copy attributes
94+
if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
95+
self.meta = deepcopy(self.meta)
9196
self.affine = self.affine.to(self.device)
9297

9398
def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:

monai/transforms/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,14 @@
206206
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
207207
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
208208
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
209+
from .meta_utility.dictionary import (
210+
FromMetaTensord,
211+
FromMetaTensorD,
212+
FromMetaTensorDict,
213+
ToMetaTensord,
214+
ToMetaTensorD,
215+
ToMetaTensorDict,
216+
)
209217
from .nvtx import (
210218
Mark,
211219
Markd,
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
A collection of dictionary-based wrappers for moving between MetaTensor types and dictionaries of data.
13+
These can be used to make backwards compatible code.
14+
15+
Class names are ended with 'd' to denote dictionary-based transforms.
16+
"""
17+
18+
from copy import deepcopy
19+
from typing import Dict, Hashable, Mapping
20+
21+
from monai.config.type_definitions import NdarrayOrTensor
22+
from monai.data.meta_tensor import MetaTensor
23+
from monai.transforms.inverse import InvertibleTransform
24+
from monai.transforms.transform import MapTransform
25+
from monai.utils.enums import PostFix, TransformBackends
26+
27+
__all__ = [
28+
"FromMetaTensord",
29+
"FromMetaTensorD",
30+
"FromMetaTensorDict",
31+
"ToMetaTensord",
32+
"ToMetaTensorD",
33+
"ToMetaTensorDict",
34+
]
35+
36+
37+
class FromMetaTensord(MapTransform, InvertibleTransform):
38+
"""
39+
Dictionary-based transform to convert MetaTensor to a dictionary.
40+
41+
If input is `{"a": MetaTensor, "b": MetaTensor}`, then output will
42+
have the form `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`.
43+
"""
44+
45+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
46+
47+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
48+
d = dict(data)
49+
for key in self.key_iterator(d):
50+
self.push_transform(d, key)
51+
im: MetaTensor = d[key] # type: ignore
52+
d.update(im.as_dict(key))
53+
return d
54+
55+
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
56+
d = deepcopy(dict(data))
57+
for key in self.key_iterator(d):
58+
# check transform
59+
_ = self.get_most_recent_transform(d, key)
60+
# do the inverse
61+
im, meta = d[key], d.pop(PostFix.meta(key), None)
62+
im = MetaTensor(im, meta=meta) # type: ignore
63+
d[key] = im
64+
# Remove the applied transform
65+
self.pop_transform(d, key)
66+
return d
67+
68+
69+
class ToMetaTensord(MapTransform, InvertibleTransform):
70+
"""
71+
Dictionary-based transform to convert a dictionary to MetaTensor.
72+
73+
If input is `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`, then output will
74+
have the form `{"a": MetaTensor, "b": MetaTensor}`.
75+
"""
76+
77+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
78+
79+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
80+
d = dict(data)
81+
for key in self.key_iterator(d):
82+
self.push_transform(d, key)
83+
im, meta = d[key], d.pop(PostFix.meta(key), None)
84+
im = MetaTensor(im, meta=meta) # type: ignore
85+
d[key] = im
86+
return d
87+
88+
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
89+
d = deepcopy(dict(data))
90+
for key in self.key_iterator(d):
91+
# check transform
92+
_ = self.get_most_recent_transform(d, key)
93+
# do the inverse
94+
im: MetaTensor = d[key] # type: ignore
95+
d.update(im.as_dict(key))
96+
# Remove the applied transform
97+
self.pop_transform(d, key)
98+
return d
99+
100+
101+
FromMetaTensorD = FromMetaTensorDict = FromMetaTensord
102+
ToMetaTensorD = ToMetaTensorDict = ToMetaTensord

monai/utils/enums.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,13 @@ class ForwardMode(Enum):
227227
class TraceKeys:
228228
"""Extra meta data keys used for traceable transforms."""
229229

230-
CLASS_NAME = "class"
231-
ID = "id"
232-
ORIG_SIZE = "orig_size"
233-
EXTRA_INFO = "extra_info"
234-
DO_TRANSFORM = "do_transforms"
235-
KEY_SUFFIX = "_transforms"
236-
NONE = "none"
230+
CLASS_NAME: str = "class"
231+
ID: str = "id"
232+
ORIG_SIZE: str = "orig_size"
233+
EXTRA_INFO: str = "extra_info"
234+
DO_TRANSFORM: str = "do_transforms"
235+
KEY_SUFFIX: str = "_transforms"
236+
NONE: str = "none"
237237

238238

239239
@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.")
@@ -287,6 +287,10 @@ def meta(key: Optional[str] = None):
287287
def orig_meta(key: Optional[str] = None):
288288
return PostFix._get_str(key, "orig_meta_dict")
289289

290+
@staticmethod
291+
def transforms(key: Optional[str] = None):
292+
return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:])
293+
290294

291295
class TransformBackends(Enum):
292296
"""

tests/test_module_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_transform_api(self):
4040
to_exclude = {"MapTransform"} # except for these transforms
4141
to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision"}
4242
to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"})
43+
to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"})
4344
xforms = {
4445
name: obj
4546
for name, obj in monai.transforms.__dict__.items()

tests/test_to_from_meta_tensord.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import random
13+
import string
14+
import unittest
15+
from copy import deepcopy
16+
from typing import Optional, Union
17+
18+
import torch
19+
from parameterized import parameterized
20+
21+
from monai.data.meta_tensor import MetaTensor
22+
from monai.transforms import FromMetaTensord, ToMetaTensord
23+
from monai.utils.enums import PostFix
24+
from monai.utils.module import get_torch_version_tuple
25+
from tests.utils import TEST_DEVICES, assert_allclose
26+
27+
PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple()
28+
29+
DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]]
30+
TESTS = []
31+
for _device in TEST_DEVICES:
32+
for _dtype in DTYPES:
33+
TESTS.append((*_device, *_dtype))
34+
35+
36+
def rand_string(min_len=5, max_len=10):
37+
str_size = random.randint(min_len, max_len)
38+
chars = string.ascii_letters + string.punctuation
39+
return "".join(random.choice(chars) for _ in range(str_size))
40+
41+
42+
class TestToFromMetaTensord(unittest.TestCase):
43+
@staticmethod
44+
def get_im(shape=None, dtype=None, device=None):
45+
if shape is None:
46+
shape = shape = (1, 10, 8)
47+
affine = torch.randint(0, 10, (4, 4))
48+
meta = {"fname": rand_string()}
49+
t = torch.rand(shape)
50+
if dtype is not None:
51+
t = t.to(dtype)
52+
if device is not None:
53+
t = t.to(device)
54+
m = MetaTensor(t.clone(), affine, meta)
55+
return m
56+
57+
def check_ids(self, a, b, should_match):
58+
comp = self.assertEqual if should_match else self.assertNotEqual
59+
comp(id(a), id(b))
60+
61+
def check(
62+
self,
63+
out: torch.Tensor,
64+
orig: torch.Tensor,
65+
*,
66+
shape: bool = True,
67+
vals: bool = True,
68+
ids: bool = True,
69+
device: Optional[Union[str, torch.device]] = None,
70+
meta: bool = True,
71+
check_ids: bool = True,
72+
**kwargs,
73+
):
74+
if device is None:
75+
device = orig.device
76+
77+
# check the image
78+
self.assertIsInstance(out, type(orig))
79+
if shape:
80+
assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape))
81+
if vals:
82+
assert_allclose(out, orig, **kwargs)
83+
if check_ids:
84+
self.check_ids(out, orig, ids)
85+
self.assertTrue(str(device) in str(out.device))
86+
87+
# check meta and affine are equal and affine is on correct device
88+
if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta:
89+
orig_meta_no_affine = deepcopy(orig.meta)
90+
del orig_meta_no_affine["affine"]
91+
out_meta_no_affine = deepcopy(out.meta)
92+
del out_meta_no_affine["affine"]
93+
self.assertEqual(orig_meta_no_affine, out_meta_no_affine)
94+
assert_allclose(out.affine, orig.affine)
95+
self.assertTrue(str(device) in str(out.affine.device))
96+
if check_ids:
97+
self.check_ids(out.affine, orig.affine, ids)
98+
self.check_ids(out.meta, orig.meta, ids)
99+
100+
@parameterized.expand(TESTS)
101+
def test_from_to_meta_tensord(self, device, dtype):
102+
m1 = self.get_im(device=device, dtype=dtype)
103+
m2 = self.get_im(device=device, dtype=dtype)
104+
m3 = self.get_im(device=device, dtype=dtype)
105+
d_metas = {"m1": m1, "m2": m2, "m3": m3}
106+
m1_meta = {k: v for k, v in m1.meta.items() if k != "affine"}
107+
m1_aff = m1.affine
108+
109+
# FROM -> forward
110+
t_from_meta = FromMetaTensord(["m1", "m2"])
111+
d_dict = t_from_meta(d_metas)
112+
113+
self.assertEqual(
114+
sorted(d_dict.keys()),
115+
[
116+
"m1",
117+
PostFix.meta("m1"),
118+
PostFix.transforms("m1"),
119+
"m2",
120+
PostFix.meta("m2"),
121+
PostFix.transforms("m2"),
122+
"m3",
123+
],
124+
)
125+
self.check(d_dict["m3"], m3, ids=True) # unchanged
126+
self.check(d_dict["m1"], m1.as_tensor(), ids=False)
127+
meta_out = {k: v for k, v in d_dict["m1_meta_dict"].items() if k != "affine"}
128+
aff_out = d_dict["m1_meta_dict"]["affine"]
129+
self.check(aff_out, m1_aff, ids=True)
130+
self.assertEqual(meta_out, m1_meta)
131+
132+
# FROM -> inverse
133+
d_meta_dict_meta = t_from_meta.inverse(d_dict)
134+
self.assertEqual(
135+
sorted(d_meta_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"]
136+
)
137+
self.check(d_meta_dict_meta["m3"], m3, ids=False) # unchanged (except deep copy in inverse)
138+
self.check(d_meta_dict_meta["m1"], m1, ids=False)
139+
meta_out = {k: v for k, v in d_meta_dict_meta["m1"].meta.items() if k != "affine"}
140+
aff_out = d_meta_dict_meta["m1"].affine
141+
self.check(aff_out, m1_aff, ids=False)
142+
self.assertEqual(meta_out, m1_meta)
143+
144+
# TO -> Forward
145+
t_to_meta = ToMetaTensord(["m1", "m2"])
146+
del d_dict["m1_transforms"]
147+
del d_dict["m2_transforms"]
148+
d_dict_meta = t_to_meta(d_dict)
149+
self.assertEqual(
150+
sorted(d_dict_meta.keys()), ["m1", PostFix.transforms("m1"), "m2", PostFix.transforms("m2"), "m3"]
151+
)
152+
self.check(d_dict_meta["m3"], m3, ids=True) # unchanged (except deep copy in inverse)
153+
self.check(d_dict_meta["m1"], m1, ids=False)
154+
meta_out = {k: v for k, v in d_dict_meta["m1"].meta.items() if k != "affine"}
155+
aff_out = d_dict_meta["m1"].meta["affine"]
156+
self.check(aff_out, m1_aff, ids=False)
157+
self.assertEqual(meta_out, m1_meta)
158+
159+
# TO -> Inverse
160+
d_dict_meta_dict = t_to_meta.inverse(d_dict_meta)
161+
self.assertEqual(
162+
sorted(d_dict_meta_dict.keys()),
163+
[
164+
"m1",
165+
PostFix.meta("m1"),
166+
PostFix.transforms("m1"),
167+
"m2",
168+
PostFix.meta("m2"),
169+
PostFix.transforms("m2"),
170+
"m3",
171+
],
172+
)
173+
self.check(d_dict_meta_dict["m3"], m3.as_tensor(), ids=False) # unchanged (except deep copy in inverse)
174+
self.check(d_dict_meta_dict["m1"], m1.as_tensor(), ids=False)
175+
meta_out = {k: v for k, v in d_dict_meta_dict["m1_meta_dict"].items() if k != "affine"}
176+
aff_out = d_dict_meta_dict["m1_meta_dict"]["affine"]
177+
self.check(aff_out, m1_aff, ids=False)
178+
self.assertEqual(meta_out, m1_meta)
179+
180+
181+
if __name__ == "__main__":
182+
unittest.main()

0 commit comments

Comments
 (0)